This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch 2.0.1-rc04-patch in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/2.0.1-rc04-patch by this push: new c1aae5e984 [enhancement](nereids)remove useless cast for floatlike type (#23621) c1aae5e984 is described below commit c1aae5e984990dba5494a5e79ff831122b229568 Author: starocean999 <40539150+starocean...@users.noreply.github.com> AuthorDate: Wed Aug 30 19:00:16 2023 +0800 [enhancement](nereids)remove useless cast for floatlike type (#23621) convert cast(c1 AS double) > 2.0 to c1 >= 2 (c1 is integer like type) --- .../rules/SimplifyComparisonPredicate.java | 155 ++++++++++--- .../test_simplify_comparison.groovy | 248 +++++++++++++++++++++ 2 files changed, 376 insertions(+), 27 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index c66e27e8b2..19574f8f16 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -31,13 +31,21 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; @@ -46,9 +54,15 @@ import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.coercion.DateLikeType; +import com.google.common.base.Preconditions; + +import java.math.BigDecimal; +import java.math.RoundingMode; + /** * simplify comparison * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral + * cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type) */ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule { @@ -65,6 +79,11 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule { Expression left = rewrite(cp.left(), context); Expression right = rewrite(cp.right(), context); + // float like type: float, double + if (left.getDataType().isFloatLikeType() && right.getDataType().isFloatLikeType()) { + return processFloatLikeTypeCoercion(cp, left, right); + } + // decimalv3 type if (left.getDataType() instanceof DecimalV3Type && right.getDataType() instanceof DecimalV3Type) { @@ -194,6 +213,26 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule { } } + private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate, + Expression left, Expression right) { + if (left instanceof Literal) { + comparisonPredicate = comparisonPredicate.commute(); + Expression temp = left; + left = right; + right = temp; + } + + if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType() + && (right instanceof DoubleLiteral || right instanceof FloatLiteral)) { + Cast cast = (Cast) left; + left = cast.child(); + BigDecimal literal = new BigDecimal(((Literal) right).getStringValue()); + return processIntegerDecimalLiteralComparison(comparisonPredicate, left, literal); + } else { + return comparisonPredicate; + } + } + private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, Expression left, Expression right) { if (left instanceof DecimalV3Literal) { @@ -203,51 +242,113 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule { right = temp; } - if (left instanceof Cast && left.child(0).getDataType().isDecimalV3Type() - && right instanceof DecimalV3Literal) { + if (left instanceof Cast && right instanceof DecimalV3Literal) { Cast cast = (Cast) left; left = cast.child(); DecimalV3Literal literal = (DecimalV3Literal) right; - if (((DecimalV3Type) left.getDataType()) - .getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) { - int toScale = ((DecimalV3Type) left.getDataType()).getScale(); - if (comparisonPredicate instanceof EqualTo) { - try { - return comparisonPredicate.withChildren(left, new DecimalV3Literal( - (DecimalV3Type) left.getDataType(), literal.getValue().setScale(toScale))); - } catch (ArithmeticException e) { - if (left.nullable()) { - // TODO: the ideal way is to return an If expr like: - // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), - // BooleanLiteral.of(false)); - // but current fold constant rule can't handle such complex expr with null literal - // before supporting complex conjuncts with null literal folding rules, - // we use a trick way like this: - return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); - } else { + if (left.getDataType().isDecimalV3Type()) { + if (((DecimalV3Type) left.getDataType()) + .getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) { + int toScale = ((DecimalV3Type) left.getDataType()).getScale(); + if (comparisonPredicate instanceof EqualTo) { + try { + return comparisonPredicate.withChildren(left, + new DecimalV3Literal((DecimalV3Type) left.getDataType(), + literal.getValue().setScale(toScale))); + } catch (ArithmeticException e) { + if (left.nullable()) { + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return new And(new IsNull(left), + new NullLiteral(BooleanType.INSTANCE)); + } else { + return BooleanLiteral.of(false); + } + } + } else if (comparisonPredicate instanceof NullSafeEqual) { + try { + return comparisonPredicate.withChildren(left, + new DecimalV3Literal((DecimalV3Type) left.getDataType(), + literal.getValue().setScale(toScale))); + } catch (ArithmeticException e) { return BooleanLiteral.of(false); } + } else if (comparisonPredicate instanceof GreaterThan + || comparisonPredicate instanceof LessThanEqual) { + return comparisonPredicate.withChildren(left, literal.roundFloor(toScale)); + } else if (comparisonPredicate instanceof LessThan + || comparisonPredicate instanceof GreaterThanEqual) { + return comparisonPredicate.withChildren(left, + literal.roundCeiling(toScale)); } - } else if (comparisonPredicate instanceof NullSafeEqual) { - try { - return comparisonPredicate.withChildren(left, new DecimalV3Literal( - (DecimalV3Type) left.getDataType(), literal.getValue().setScale(toScale))); - } catch (ArithmeticException e) { + } + } else if (left.getDataType().isIntegerLikeType()) { + return processIntegerDecimalLiteralComparison(comparisonPredicate, left, + literal.getValue()); + } + } + + return comparisonPredicate; + } + + private Expression processIntegerDecimalLiteralComparison( + ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { + // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint + if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { + if (literal.scale() > 0) { + if (comparisonPredicate instanceof EqualTo) { + if (left.nullable()) { + // TODO: the ideal way is to return an If expr like: + // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE), + // BooleanLiteral.of(false)); + // but current fold constant rule can't handle such complex expr with null literal + // before supporting complex conjuncts with null literal folding rules, + // we use a trick way like this: + return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE)); + } else { return BooleanLiteral.of(false); } + } else if (comparisonPredicate instanceof NullSafeEqual) { + return BooleanLiteral.of(false); } else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof LessThanEqual) { - return comparisonPredicate.withChildren(left, literal.roundFloor(toScale)); + return comparisonPredicate.withChildren(left, + convertDecimalToIntegerLikeLiteral( + literal.setScale(0, RoundingMode.FLOOR))); } else if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof GreaterThanEqual) { - return comparisonPredicate.withChildren(left, literal.roundCeiling(toScale)); + return comparisonPredicate.withChildren(left, + convertDecimalToIntegerLikeLiteral( + literal.setScale(0, RoundingMode.CEILING))); } + } else { + return comparisonPredicate.withChildren(left, + convertDecimalToIntegerLikeLiteral(literal)); } } - return comparisonPredicate; } + private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { + Preconditions.checkArgument( + decimal.scale() == 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, + "decimal literal must have 0 scale and smaller than Long.MAX_VALUE"); + long val = decimal.longValue(); + if (val <= Byte.MAX_VALUE) { + return new TinyIntLiteral((byte) val); + } else if (val <= Short.MAX_VALUE) { + return new SmallIntLiteral((short) val); + } else if (val <= Integer.MAX_VALUE) { + return new IntegerLiteral((int) val); + } else { + return new BigIntLiteral(val); + } + } + private Expression migrateCastToDateTime(Cast cast) { //cast( cast(v as date) as datetime) if v is datetime, set left = v if (cast.child() instanceof Cast diff --git a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy index 53c0ff9a12..4b3cd3bdca 100644 --- a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy +++ b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy @@ -72,4 +72,252 @@ suite("test_simplify_comparison") { } sql "select cast('1234' as decimalv3(18,4)) > 2000;" + + sql 'drop table if exists simple_test_table_t;' + sql """CREATE TABLE IF NOT EXISTS `simple_test_table_t` ( + a tinyint, + b smallint, + c int, + d bigint, + e largeint + ) ENGINE=OLAP + UNIQUE KEY (`a`) + DISTRIBUTED BY HASH(`a`) BUCKETS 120 + PROPERTIES ( + "replication_num" = "1", + "in_memory" = "false", + "compression" = "LZ4" + );""" + + explain { + sql "verbose select * from simple_test_table_t where a = cast(1.0 as double) and b = cast(1.0 as double) and c = cast(1.0 as double) and d = cast(1.0 as double);" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e = cast(1.0 as double);" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a > cast(1.0 as double) and b > cast(1.0 as double) and c > cast(1.0 as double) and d > cast(1.0 as double);" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e > cast(1.0 as double);" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a < cast(1.0 as double) and b < cast(1.0 as double) and c < cast(1.0 as double) and d < cast(1.0 as double);" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e < cast(1.0 as double);" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a >= cast(1.0 as double) and b >= cast(1.0 as double) and c >= cast(1.0 as double) and d >= cast(1.0 as double);" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e >= cast(1.0 as double);" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a <= cast(1.0 as double) and b <= cast(1.0 as double) and c <= cast(1.0 as double) and d <= cast(1.0 as double);" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e <= cast(1.0 as double);" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a = cast(1.1 as double) and b = cast(1.1 as double) and c = cast(1.1 as double) and d = cast(1.1 as double);" + contains "a[#0] IS NULL" + contains "b[#1] IS NULL" + contains "c[#2] IS NULL" + contains "d[#3] IS NULL" + contains "AND NULL" + } + + explain { + sql "verbose select * from simple_test_table_t where e = cast(1.1 as double);" + contains "CAST(e[#4] AS DOUBLE) = 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a > cast(1.1 as double) and b > cast(1.1 as double) and c > cast(1.1 as double) and d > cast(1.1 as double);" + contains "a[#0] > 1" + contains "b[#1] > 1" + contains "c[#2] > 1" + contains "d[#3] > 1" + } + + explain { + sql "verbose select * from simple_test_table_t where e > cast(1.1 as double);" + contains "CAST(e[#4] AS DOUBLE) > 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a < cast(1.1 as double) and b < cast(1.1 as double) and c < cast(1.1 as double) and d < cast(1.1 as double);" + contains "a[#0] < 2" + contains "b[#1] < 2" + contains "c[#2] < 2" + contains "d[#3] < 2" + } + + explain { + sql "verbose select * from simple_test_table_t where e < cast(1.1 as double);" + contains "CAST(e[#4] AS DOUBLE) < 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a >= cast(1.1 as double) and b >= cast(1.1 as double) and c >= cast(1.1 as double) and d >= cast(1.1 as double);" + contains "a[#0] >= 2" + contains "b[#1] >= 2" + contains "c[#2] >= 2" + contains "d[#3] >= 2" + } + + explain { + sql "verbose select * from simple_test_table_t where e >= cast(1.1 as double);" + contains "CAST(e[#4] AS DOUBLE) >= 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a <= cast(1.1 as double) and b <= cast(1.1 as double) and c <= cast(1.1 as double) and d <= cast(1.1 as double);" + contains "a[#0] <= 1" + contains "b[#1] <= 1" + contains "c[#2] <= 1" + contains "d[#3] <= 1" + } + + explain { + sql "verbose select * from simple_test_table_t where e <= cast(1.1 as double);" + contains "CAST(e[#4] AS DOUBLE) <= 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a = 1.0 and b = 1.0 and c = 1.0 and d = 1.0;" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e = 1.0;" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a > 1.0 and b > 1.0 and c > 1.0 and d > 1.0;" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e > 1.0;" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a < 1.0 and b < 1.0 and c < 1.0 and d < 1.0;" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e < 1.0;" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a >= 1.0 and b >= 1.0 and c >= 1.0 and d >= 1.0;" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e >= 1.0;" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a <= 1.0 and b <= 1.0 and c <= 1.0 and d <= 1.0;" + notContains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where e <= 1.0;" + contains "CAST" + } + + explain { + sql "verbose select * from simple_test_table_t where a = 1.1 and b = 1.1 and c = 1.1 and d = 1.1;" + contains "a[#0] IS NULL" + contains "b[#1] IS NULL" + contains "c[#2] IS NULL" + contains "d[#3] IS NULL" + contains "AND NULL" + } + + explain { + sql "verbose select * from simple_test_table_t where e = 1.1;" + contains "CAST(e[#4] AS DOUBLE) = 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a > 1.1 and b > 1.1 and c > 1.1 and d > 1.1;" + contains "a[#0] > 1" + contains "b[#1] > 1" + contains "c[#2] > 1" + contains "d[#3] > 1" + } + + explain { + sql "verbose select * from simple_test_table_t where e > 1.1;" + contains "CAST(e[#4] AS DOUBLE) > 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a < 1.1 and b < 1.1 and c < 1.1 and d < 1.1;" + contains "a[#0] < 2" + contains "b[#1] < 2" + contains "c[#2] < 2" + contains "d[#3] < 2" + } + + explain { + sql "verbose select * from simple_test_table_t where e < 1.1;" + contains "CAST(e[#4] AS DOUBLE) < 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a >= 1.1 and b >= 1.1 and c >= 1.1 and d >= 1.1;" + contains "a[#0] >= 2" + contains "b[#1] >= 2" + contains "c[#2] >= 2" + contains "d[#3] >= 2" + } + + explain { + sql "verbose select * from simple_test_table_t where e >= 1.1;" + contains "CAST(e[#4] AS DOUBLE) >= 1.1" + } + + explain { + sql "verbose select * from simple_test_table_t where a <= 1.1 and b <= 1.1 and c <= 1.1 and d <= 1.1;" + contains "a[#0] <= 1" + contains "b[#1] <= 1" + contains "c[#2] <= 1" + contains "d[#3] <= 1" + } + + explain { + sql "verbose select * from simple_test_table_t where e <= 1.1;" + contains "CAST(e[#4] AS DOUBLE) <= 1.1" + } } \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org