This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch 2.0.1-rc04-patch
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/2.0.1-rc04-patch by this push:
     new c1aae5e984 [enhancement](nereids)remove useless cast for floatlike type (#23621)
c1aae5e984 is described below

commit c1aae5e984990dba5494a5e79ff831122b229568
Author: starocean999 <40539150+starocean...@users.noreply.github.com>
AuthorDate: Wed Aug 30 19:00:16 2023 +0800

    [enhancement](nereids)remove useless cast for floatlike type (#23621)
    
    convert cast(c1 AS double) > 2.0 to c1 > 2 (c1 is an integer-like type)
---
 .../rules/SimplifyComparisonPredicate.java         | 155 ++++++++++---
 .../test_simplify_comparison.groovy                | 248 +++++++++++++++++++++
 2 files changed, 376 insertions(+), 27 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
index c66e27e8b2..19574f8f16 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
@@ -31,13 +31,21 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.LessThanEqual;
 import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
 import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DateTimeType;
 import org.apache.doris.nereids.types.DateTimeV2Type;
@@ -46,9 +54,15 @@ import org.apache.doris.nereids.types.DateV2Type;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 
+import com.google.common.base.Preconditions;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+
 /**
  * simplify comparison
  * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
+ *          cast(c1 AS double) > 2.0 --> c1 > 2 (c1 is integer like type)
  */
 public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule 
{
 
@@ -65,6 +79,11 @@ public class SimplifyComparisonPredicate extends 
AbstractExpressionRewriteRule {
         Expression left = rewrite(cp.left(), context);
         Expression right = rewrite(cp.right(), context);
 
+        // float like type: float, double
+        if (left.getDataType().isFloatLikeType() && 
right.getDataType().isFloatLikeType()) {
+            return processFloatLikeTypeCoercion(cp, left, right);
+        }
+
         // decimalv3 type
         if (left.getDataType() instanceof DecimalV3Type
                 && right.getDataType() instanceof DecimalV3Type) {
@@ -194,6 +213,26 @@ public class SimplifyComparisonPredicate extends 
AbstractExpressionRewriteRule {
         }
     }
 
+    /**
+     * Removes the cast from an integer-like expression that is compared with a float/double literal,
+     * e.g. {@code cast(c1 AS double) > 2.5} becomes {@code c1 > 2}.
+     *
+     * <p>When the literal is on the left the predicate is commuted first, so the non-literal side is on
+     * the left. The rewrite only applies to
+     * {@code Cast(<integer-like expr>) op (DoubleLiteral | FloatLiteral)}; in every other case the
+     * predicate (commuted if the literal was on the left) is returned as is.
+     *
+     * <p>NOTE(review): bigint -> double (and int -> float) casts are not exact for large magnitudes, so
+     * the integer comparison can disagree with the floating-point one for such values; confirm whether
+     * the rewrite should be limited to integer types that the target float type represents exactly.
+     *
+     * @param comparisonPredicate the comparison being simplified
+     * @param left the rewritten left child of the comparison
+     * @param right the rewritten right child of the comparison
+     * @return the simplified comparison, or the (possibly commuted) original when nothing applies
+     */
+    private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate,
+            Expression left, Expression right) {
+        if (left instanceof Literal) {
+            comparisonPredicate = comparisonPredicate.commute();
+            Expression temp = left;
+            left = right;
+            right = temp;
+        }
+
+        if (!(left instanceof Cast && left.child(0).getDataType().isIntegerLikeType()
+                && (right instanceof DoubleLiteral || right instanceof FloatLiteral))) {
+            return comparisonPredicate;
+        }
+
+        BigDecimal literal;
+        try {
+            literal = new BigDecimal(((Literal) right).getStringValue());
+        } catch (NumberFormatException e) {
+            // NaN and +/-Infinity have no decimal representation; keep the comparison untouched
+            // instead of failing the whole query during planning
+            return comparisonPredicate;
+        }
+        Cast cast = (Cast) left;
+        return processIntegerDecimalLiteralComparison(comparisonPredicate, cast.child(), literal);
+    }
+
     private Expression processDecimalV3TypeCoercion(ComparisonPredicate 
comparisonPredicate,
             Expression left, Expression right) {
         if (left instanceof DecimalV3Literal) {
@@ -203,51 +242,113 @@ public class SimplifyComparisonPredicate extends 
AbstractExpressionRewriteRule {
             right = temp;
         }
 
-        if (left instanceof Cast && 
left.child(0).getDataType().isDecimalV3Type()
-                && right instanceof DecimalV3Literal) {
+        if (left instanceof Cast && right instanceof DecimalV3Literal) {
             Cast cast = (Cast) left;
             left = cast.child();
             DecimalV3Literal literal = (DecimalV3Literal) right;
-            if (((DecimalV3Type) left.getDataType())
-                    .getScale() < ((DecimalV3Type) 
literal.getDataType()).getScale()) {
-                int toScale = ((DecimalV3Type) left.getDataType()).getScale();
-                if (comparisonPredicate instanceof EqualTo) {
-                    try {
-                        return comparisonPredicate.withChildren(left, new 
DecimalV3Literal(
-                                (DecimalV3Type) left.getDataType(), 
literal.getValue().setScale(toScale)));
-                    } catch (ArithmeticException e) {
-                        if (left.nullable()) {
-                            // TODO: the ideal way is to return an If expr 
like:
-                            // return new If(new IsNull(left), new 
NullLiteral(BooleanType.INSTANCE),
-                            // BooleanLiteral.of(false));
-                            // but current fold constant rule can't handle 
such complex expr with null literal
-                            // before supporting complex conjuncts with null 
literal folding rules,
-                            // we use a trick way like this:
-                            return new And(new IsNull(left), new 
NullLiteral(BooleanType.INSTANCE));
-                        } else {
+            if (left.getDataType().isDecimalV3Type()) {
+                if (((DecimalV3Type) left.getDataType())
+                        .getScale() < ((DecimalV3Type) 
literal.getDataType()).getScale()) {
+                    int toScale = ((DecimalV3Type) 
left.getDataType()).getScale();
+                    if (comparisonPredicate instanceof EqualTo) {
+                        try {
+                            return comparisonPredicate.withChildren(left,
+                                    new DecimalV3Literal((DecimalV3Type) 
left.getDataType(),
+                                            
literal.getValue().setScale(toScale)));
+                        } catch (ArithmeticException e) {
+                            if (left.nullable()) {
+                                // TODO: the ideal way is to return an If expr 
like:
+                                // return new If(new IsNull(left), new 
NullLiteral(BooleanType.INSTANCE),
+                                // BooleanLiteral.of(false));
+                                // but current fold constant rule can't handle 
such complex expr with null literal
+                                // before supporting complex conjuncts with 
null literal folding rules,
+                                // we use a trick way like this:
+                                return new And(new IsNull(left),
+                                        new NullLiteral(BooleanType.INSTANCE));
+                            } else {
+                                return BooleanLiteral.of(false);
+                            }
+                        }
+                    } else if (comparisonPredicate instanceof NullSafeEqual) {
+                        try {
+                            return comparisonPredicate.withChildren(left,
+                                    new DecimalV3Literal((DecimalV3Type) 
left.getDataType(),
+                                            
literal.getValue().setScale(toScale)));
+                        } catch (ArithmeticException e) {
                             return BooleanLiteral.of(false);
                         }
+                    } else if (comparisonPredicate instanceof GreaterThan
+                            || comparisonPredicate instanceof LessThanEqual) {
+                        return comparisonPredicate.withChildren(left, 
literal.roundFloor(toScale));
+                    } else if (comparisonPredicate instanceof LessThan
+                            || comparisonPredicate instanceof 
GreaterThanEqual) {
+                        return comparisonPredicate.withChildren(left,
+                                literal.roundCeiling(toScale));
                     }
-                } else if (comparisonPredicate instanceof NullSafeEqual) {
-                    try {
-                        return comparisonPredicate.withChildren(left, new 
DecimalV3Literal(
-                                (DecimalV3Type) left.getDataType(), 
literal.getValue().setScale(toScale)));
-                    } catch (ArithmeticException e) {
+                }
+            } else if (left.getDataType().isIntegerLikeType()) {
+                return 
processIntegerDecimalLiteralComparison(comparisonPredicate, left,
+                        literal.getValue());
+            }
+        }
+
+        return comparisonPredicate;
+    }
+
+    /**
+     * Rewrites {@code left op literal}, where {@code left} is an integer-like expression and {@code literal}
+     * is an exact decimal value, into a comparison between {@code left} and an integer-like literal.
+     *
+     * <ul>
+     *   <li>integral literal: compare with the same integral value;</li>
+     *   <li>fractional literal with {@code =}: never true; the result evaluates to NULL when {@code left}
+     *       is NULL and to false otherwise;</li>
+     *   <li>fractional literal with {@code <=>}: false;</li>
+     *   <li>fractional literal with {@code >} or {@code <=}: compare with floor(literal);</li>
+     *   <li>fractional literal with {@code <} or {@code >=}: compare with ceiling(literal).</li>
+     * </ul>
+     * Literals outside the range of {@code long} are returned unchanged.
+     */
+    private Expression processIntegerDecimalLiteralComparison(
+            ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
+        // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint;
+        // a literal outside the long range cannot be turned into any of their literals
+        if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
+                && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
+            // normalize the representation: 1.0 (scale 1) is integral and must not be treated as
+            // fractional, and 1E+1 (scale -1) is brought back to scale 0 (i.e. 10)
+            literal = literal.stripTrailingZeros();
+            if (literal.scale() < 0) {
+                literal = literal.setScale(0);
+            }
+            if (literal.scale() > 0) {
+                if (comparisonPredicate instanceof EqualTo) {
+                    if (left.nullable()) {
+                        // TODO: the ideal way is to return an If expr like:
+                        // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
+                        // BooleanLiteral.of(false));
+                        // but current fold constant rule can't handle such complex expr with null literal
+                        // before supporting complex conjuncts with null literal folding rules,
+                        // we use a trick way like this:
+                        return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
+                    } else {
                         return BooleanLiteral.of(false);
                     }
+                } else if (comparisonPredicate instanceof NullSafeEqual) {
+                    return BooleanLiteral.of(false);
                 } else if (comparisonPredicate instanceof GreaterThan
                         || comparisonPredicate instanceof LessThanEqual) {
-                    return comparisonPredicate.withChildren(left, literal.roundFloor(toScale));
+                    // left > 2.5 <=> left > 2;  left <= 2.5 <=> left <= 2
+                    return comparisonPredicate.withChildren(left,
+                            convertDecimalToIntegerLikeLiteral(
+                                    literal.setScale(0, RoundingMode.FLOOR)));
                 } else if (comparisonPredicate instanceof LessThan
                         || comparisonPredicate instanceof GreaterThanEqual) {
-                    return comparisonPredicate.withChildren(left, literal.roundCeiling(toScale));
+                    // left < 2.5 <=> left < 3;  left >= 2.5 <=> left >= 3
+                    return comparisonPredicate.withChildren(left,
+                            convertDecimalToIntegerLikeLiteral(
+                                    literal.setScale(0, RoundingMode.CEILING)));
                 }
+            } else {
+                return comparisonPredicate.withChildren(left,
+                        convertDecimalToIntegerLikeLiteral(literal));
             }
         }
-
         return comparisonPredicate;
     }
 
+    /**
+     * Converts an integral decimal value into the narrowest integer-like literal able to hold it:
+     * tinyint, then smallint, then int, otherwise bigint.
+     *
+     * @param decimal an integral value ({@code scale() <= 0}) within [Long.MIN_VALUE, Long.MAX_VALUE]
+     * @return the narrowest integer-like literal equal to {@code decimal}
+     * @throws IllegalArgumentException if {@code decimal} has a positive scale or is outside the long range
+     */
+    private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
+        // a negative scale is still integral (e.g. 1E+1 == 10), so only a positive scale is rejected
+        Preconditions.checkArgument(decimal.scale() <= 0
+                && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
+                && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
+                "decimal literal must have non-positive scale and be in range [Long.MIN_VALUE, Long.MAX_VALUE]");
+        long val = decimal.longValue();
+        // check both bounds: with only the upper bound, -200 would be narrowed by (byte) -200 to 56
+        if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) {
+            return new TinyIntLiteral((byte) val);
+        } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) {
+            return new SmallIntLiteral((short) val);
+        } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) {
+            return new IntegerLiteral((int) val);
+        } else {
+            return new BigIntLiteral(val);
+        }
+    }
+
     private Expression migrateCastToDateTime(Cast cast) {
         //cast( cast(v as date) as datetime) if v is datetime, set left = v
         if (cast.child() instanceof Cast
diff --git 
a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy 
b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
index 53c0ff9a12..4b3cd3bdca 100644
--- a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
+++ b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
@@ -72,4 +72,252 @@ suite("test_simplify_comparison") {
     }
 
     sql "select cast('1234' as decimalv3(18,4)) > 2000;"
+
+    sql 'drop table if exists simple_test_table_t;'
+    sql """CREATE TABLE IF NOT EXISTS `simple_test_table_t` (
+            a tinyint,
+            b smallint,
+            c int,
+            d bigint,
+            e largeint
+            ) ENGINE=OLAP
+            UNIQUE KEY (`a`)
+            DISTRIBUTED BY HASH(`a`) BUCKETS 120
+            PROPERTIES (
+            "replication_num" = "1",
+            "in_memory" = "false",
+            "compression" = "LZ4"
+            );"""
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a = cast(1.0 as 
double) and b = cast(1.0 as double) and c = cast(1.0 as double) and d = 
cast(1.0 as double);"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e = cast(1.0 as 
double);"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a > cast(1.0 as 
double) and b > cast(1.0 as double) and c > cast(1.0 as double) and d > 
cast(1.0 as double);"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e > cast(1.0 as 
double);"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a < cast(1.0 as 
double) and b < cast(1.0 as double) and c < cast(1.0 as double) and d < 
cast(1.0 as double);"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e < cast(1.0 as 
double);"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a >= cast(1.0 as 
double) and b >= cast(1.0 as double) and c >= cast(1.0 as double) and d >= 
cast(1.0 as double);"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e >= cast(1.0 as 
double);"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a <= cast(1.0 as 
double) and b <= cast(1.0 as double) and c <= cast(1.0 as double) and d <= 
cast(1.0 as double);"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e <= cast(1.0 as 
double);"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a = cast(1.1 as 
double) and b = cast(1.1 as double) and c = cast(1.1 as double) and d = 
cast(1.1 as double);"
+        contains "a[#0] IS NULL"
+        contains "b[#1] IS NULL"
+        contains "c[#2] IS NULL"
+        contains "d[#3] IS NULL"
+        contains "AND NULL"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e = cast(1.1 as 
double);"
+        contains "CAST(e[#4] AS DOUBLE) = 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a > cast(1.1 as 
double) and b > cast(1.1 as double) and c > cast(1.1 as double) and d > 
cast(1.1 as double);"
+        contains "a[#0] > 1"
+        contains "b[#1] > 1"
+        contains "c[#2] > 1"
+        contains "d[#3] > 1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e > cast(1.1 as 
double);"
+        contains "CAST(e[#4] AS DOUBLE) > 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a < cast(1.1 as 
double) and b < cast(1.1 as double) and c < cast(1.1 as double) and d < 
cast(1.1 as double);"
+        contains "a[#0] < 2"
+        contains "b[#1] < 2"
+        contains "c[#2] < 2"
+        contains "d[#3] < 2"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e < cast(1.1 as 
double);"
+        contains "CAST(e[#4] AS DOUBLE) < 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a >= cast(1.1 as 
double) and b >= cast(1.1 as double) and c >= cast(1.1 as double) and d >= 
cast(1.1 as double);"
+        contains "a[#0] >= 2"
+        contains "b[#1] >= 2"
+        contains "c[#2] >= 2"
+        contains "d[#3] >= 2"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e >= cast(1.1 as 
double);"
+        contains "CAST(e[#4] AS DOUBLE) >= 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a <= cast(1.1 as 
double) and b <= cast(1.1 as double) and c <= cast(1.1 as double) and d <= 
cast(1.1 as double);"
+        contains "a[#0] <= 1"
+        contains "b[#1] <= 1"
+        contains "c[#2] <= 1"
+        contains "d[#3] <= 1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e <= cast(1.1 as 
double);"
+        contains "CAST(e[#4] AS DOUBLE) <= 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a = 1.0 and b = 
1.0 and c = 1.0 and d = 1.0;"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e = 1.0;"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a > 1.0 and b > 
1.0 and c > 1.0 and d > 1.0;"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e > 1.0;"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a < 1.0 and b < 
1.0 and c < 1.0 and d < 1.0;"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e < 1.0;"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a >= 1.0 and b >= 
1.0 and c >= 1.0 and d >= 1.0;"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e >= 1.0;"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a <= 1.0 and b <= 
1.0 and c <= 1.0 and d <= 1.0;"
+        notContains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e <= 1.0;"
+        contains "CAST"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a = 1.1 and b = 
1.1 and c = 1.1 and d = 1.1;"
+        contains "a[#0] IS NULL"
+        contains "b[#1] IS NULL"
+        contains "c[#2] IS NULL"
+        contains "d[#3] IS NULL"
+        contains "AND NULL"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e = 1.1;"
+        contains "CAST(e[#4] AS DOUBLE) = 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a > 1.1 and b > 
1.1 and c > 1.1 and d > 1.1;"
+        contains "a[#0] > 1"
+        contains "b[#1] > 1"
+        contains "c[#2] > 1"
+        contains "d[#3] > 1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e > 1.1;"
+        contains "CAST(e[#4] AS DOUBLE) > 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a < 1.1 and b < 
1.1 and c < 1.1 and d < 1.1;"
+        contains "a[#0] < 2"
+        contains "b[#1] < 2"
+        contains "c[#2] < 2"
+        contains "d[#3] < 2"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e < 1.1;"
+        contains "CAST(e[#4] AS DOUBLE) < 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a >= 1.1 and b >= 
1.1 and c >= 1.1 and d >= 1.1;"
+        contains "a[#0] >= 2"
+        contains "b[#1] >= 2"
+        contains "c[#2] >= 2"
+        contains "d[#3] >= 2"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e >= 1.1;"
+        contains "CAST(e[#4] AS DOUBLE) >= 1.1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where a <= 1.1 and b <= 
1.1 and c <= 1.1 and d <= 1.1;"
+        contains "a[#0] <= 1"
+        contains "b[#1] <= 1"
+        contains "c[#2] <= 1"
+        contains "d[#3] <= 1"
+    }
+
+    explain {
+        sql "verbose select * from simple_test_table_t where e <= 1.1;"
+        contains "CAST(e[#4] AS DOUBLE) <= 1.1"
+    }
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to