This is an automated email from the ASF dual-hosted git repository.

liyang pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit 3914626750dab422929767e7652971589a3061be
Author: Pengfei Zhan <dethr...@gmail.com>
AuthorDate: Mon Nov 20 11:15:31 2023 +0800

    KYLIN-5875 Enhance the simplification method of filter condition for segment pruning
---
 .../routing/HeterogeneousSegmentPruningTest.java   | 14 ++++-
 .../java/org/apache/kylin/query/util/RexUtils.java | 14 ++++-
 .../apache/kylin/util/FilterConditionExpander.java | 64 ++++++++++++----------
 .../kylin/query/runtime/SparderRexVisitor.scala    | 13 +++--
 .../org/apache/spark/sql/SparderTypeUtil.scala     |  2 +
 5 files changed, 68 insertions(+), 39 deletions(-)

diff --git a/src/kylin-it/src/test/java/org/apache/kylin/query/routing/HeterogeneousSegmentPruningTest.java b/src/kylin-it/src/test/java/org/apache/kylin/query/routing/HeterogeneousSegmentPruningTest.java
index 133df6601b..d8e5437e22 100644
--- a/src/kylin-it/src/test/java/org/apache/kylin/query/routing/HeterogeneousSegmentPruningTest.java
+++ b/src/kylin-it/src/test/java/org/apache/kylin/query/routing/HeterogeneousSegmentPruningTest.java
@@ -52,7 +52,6 @@ import org.junit.Test;
 
 import lombok.val;
 
-
 public class HeterogeneousSegmentPruningTest extends NLocalWithSparkSessionTest {
 
     @Test
@@ -316,6 +315,15 @@ public class HeterogeneousSegmentPruningTest extends NLocalWithSparkSessionTest
 
         val sql2_date_string = sql + "where cal_dt >= '2012-01-03' and cal_dt < '2012-01-10' group by cal_dt";
         assertPrunedSegmentsRange(project, sql2_date_string, dfId, expectedRanges, layout_10001, null);
+        val sql2_cast_date_str = sql
+                + "where cast(cal_dt as date) >= '2012-01-03' and cast(cal_dt as date) < '2012-01-10' group by cal_dt";
+        assertPrunedSegmentsRange(project, sql2_cast_date_str, dfId, expectedRanges, layout_10001, null);
+        val sql_date_cast_literal = sql
+                + "where cal_dt >= cast('2012-01-03' as date) and cal_dt < cast('2012-01-10' as date) group by cal_dt";
+        assertPrunedSegmentsRange(project, sql_date_cast_literal, dfId, expectedRanges, layout_10001, null);
+        val sql_cast_date_cast_literal = sql
+                + "where cast(cal_dt as date) >= cast('2012-01-03' as date) and cast(cal_dt as date) < cast('2012-01-10' as date) group by cal_dt";
+        assertPrunedSegmentsRange(project, sql_cast_date_cast_literal, dfId, expectedRanges, layout_10001, null);
 
         // pruned segments do not have capable layout to answer
         val sql3_no_layout = "select trans_id from test_kylin_fact "
@@ -453,7 +461,7 @@ public class HeterogeneousSegmentPruningTest extends NLocalWithSparkSessionTest
         // val seg3Id = "54eaf96d-6146-45d2-b94e-d5d187f89919"
         // val seg4Id = "411f40b9-a80a-4453-90a9-409aac6f7632"
         // val seg5Id = "a8318597-cb75-416f-8eb8-96ea285dd2b4"
-        // 
+        //
         val sql = "with T1 as (select cal_dt, trans_id \n" + "from 
test_kylin_fact inner join test_account \n"
                 + "on test_kylin_fact.seller_id = test_account.account_id \n"
                 + "where cal_dt between date'2012-01-01' and 
date'2012-01-03'\n" + "group by cal_dt, trans_id),\n"
@@ -522,7 +530,7 @@ public class HeterogeneousSegmentPruningTest extends NLocalWithSparkSessionTest
         // val seg3Id = "54eaf96d-6146-45d2-b94e-d5d187f89919"
         // val seg4Id = "411f40b9-a80a-4453-90a9-409aac6f7632"
         // val seg5Id = "a8318597-cb75-416f-8eb8-96ea285dd2b4"
-        // 
+        //
         val sql = "with T1 as (select cal_dt, trans_id \n" + "from 
test_kylin_fact inner join test_account \n"
                 + "on test_kylin_fact.seller_id = test_account.account_id \n"
                 + "where cal_dt between date'2012-01-01' and 
date'2012-01-03'\n" + "group by cal_dt, trans_id)\n";
diff --git a/src/query-common/src/main/java/org/apache/kylin/query/util/RexUtils.java b/src/query-common/src/main/java/org/apache/kylin/query/util/RexUtils.java
index fc70c5697c..383e0f8739 100644
--- a/src/query-common/src/main/java/org/apache/kylin/query/util/RexUtils.java
+++ b/src/query-common/src/main/java/org/apache/kylin/query/util/RexUtils.java
@@ -44,6 +44,8 @@ import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.util.DateString;
 import org.apache.calcite.util.TimestampString;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.kylin.guava30.shaded.common.base.Preconditions;
 import org.apache.kylin.guava30.shaded.common.collect.Lists;
 import org.apache.kylin.metadata.datatype.DataType;
 import org.apache.kylin.metadata.model.TblColRef;
@@ -263,13 +265,21 @@ public class RexUtils {
         RelDataType relDataType;
         switch (colType.getName()) {
         case DataType.DATE:
-            return rexBuilder.makeDateLiteral(new DateString(value));
+            // Support the case where the column type is date but the value is a timestamp string.
+            // For example: DEFAULT.TEST_KYLIN_FACT.CAL_DT has type date, and the filter
+            // condition is: cast("cal_dt" as timestamp) >= timestamp '2012-01-01 00:00:00'.
+            // FilterConditionExpander translates it into the comparison CAL_DT >= date '2012-01-01'.
+            // Note that dropping the time part this way is potentially unsafe.
+            String[] splits = StringUtils.split(value.trim(), " ");
+            Preconditions.checkArgument(splits.length >= 1, "failed to split value: %s", value);
+            return rexBuilder.makeDateLiteral(new DateString(splits[0]));
         case DataType.TIMESTAMP:
             relDataType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.TIMESTAMP);
             return rexBuilder.makeTimestampLiteral(new TimestampString(value), relDataType.getPrecision());
         case DataType.VARCHAR:
         case DataType.STRING:
-            return rexBuilder.makeLiteral(value);
+            relDataType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.VARCHAR, colType.getPrecision());
+            return rexBuilder.makeLiteral(value, relDataType, false);
         case DataType.INTEGER:
             relDataType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
             return rexBuilder.makeLiteral(Integer.parseInt(value), relDataType, false);
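
For readers skimming the hunk above: the DATE branch now tolerates values like '2012-01-01 00:00:00' by keeping only the date part before building the DateString. A minimal sketch of that normalization, using plain JDK string handling in place of the patch's commons-lang StringUtils and shaded-Guava Preconditions (class and method names here are illustrative, not from the patch):

    // Illustrative only: mirrors the new DATE branch's "take the date part" step.
    public class DateLiteralSketch {
        // "2012-01-01 00:00:00" -> "2012-01-01"; a bare date passes through unchanged.
        static String datePart(String value) {
            String trimmed = value.trim();
            if (trimmed.isEmpty()) {
                throw new IllegalArgumentException("failed to split value: " + value);
            }
            int space = trimmed.indexOf(' ');
            return space < 0 ? trimmed : trimmed.substring(0, space);
        }

        public static void main(String[] args) {
            System.out.println(datePart("2012-01-01 00:00:00")); // 2012-01-01
            System.out.println(datePart("2012-01-01"));          // 2012-01-01
        }
    }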
diff --git a/src/query-common/src/main/java/org/apache/kylin/util/FilterConditionExpander.java b/src/query-common/src/main/java/org/apache/kylin/util/FilterConditionExpander.java
index f31bfe1ac9..c8087c6753 100644
--- a/src/query-common/src/main/java/org/apache/kylin/util/FilterConditionExpander.java
+++ b/src/query-common/src/main/java/org/apache/kylin/util/FilterConditionExpander.java
@@ -35,16 +35,17 @@ import org.apache.calcite.rex.RexInputRef;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.util.DateString;
 import org.apache.calcite.util.NlsString;
-import org.apache.calcite.util.TimestampString;
 import org.apache.kylin.common.exception.KylinException;
 import org.apache.kylin.guava30.shaded.common.collect.Lists;
+import org.apache.kylin.metadata.datatype.DataType;
 import org.apache.kylin.query.relnode.ContextUtil;
 import org.apache.kylin.query.relnode.OlapContext;
 import org.apache.kylin.query.relnode.OlapRel;
 import org.apache.kylin.query.relnode.OlapTableScan;
+import org.apache.kylin.query.util.RexUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,9 +55,6 @@ import lombok.var;
 public class FilterConditionExpander {
     public static final Logger logger = LoggerFactory.getLogger(FilterConditionExpander.class);
 
-    private static final String DATE = "date";
-    private static final String TIMESTAMP = "timestamp";
-
     private final OlapContext context;
     private final RelNode currentRel;
     private final RexBuilder rexBuilder;
@@ -148,31 +146,41 @@ public class FilterConditionExpander {
             }
 
             // col <op> lit
-            val op1 = call.getOperands().get(1);
-            if (call.getOperands().size() == 2 && op1 instanceof RexLiteral) {
+            return simplify(call, lInputRef);
+        }
+
+        return null;
+    }
+
+    private RexNode simplify(RexCall call, RexInputRef lInputRef) {
+        val op1 = call.getOperands().get(1);
+        if (call.getOperands().size() == 2) {
+            // accept cases: the right rel is literal
+            if (op1 instanceof RexLiteral) {
                 var rLit = (RexLiteral) op1;
-                rLit = ((RexLiteral) op1).getValue() instanceof NlsString ? transformRexLiteral(lInputRef, rLit) : rLit;
+                rLit = transformRexLiteral(lInputRef, rLit);
                 return rexBuilder.makeCall(call.getOperator(), lInputRef, rLit);
             }
+            // accept cases: the right rel is cast(literal as datatype)
+            if (op1.isA(SqlKind.CAST)) {
+                RexCall c = (RexCall) op1;
+                RexNode rexNode = c.getOperands().get(0);
+                if (rexNode instanceof RexLiteral) {
+                    RexLiteral rLit = transformRexLiteral(lInputRef, (RexLiteral) rexNode);
+                    return rexBuilder.makeCall(call.getOperator(), lInputRef, rLit);
+                }
+            }
         }
-
         return null;
     }
 
     private RexNode convertIn(RexInputRef rexInputRef, List<RexNode> extendedOperands, boolean isIn) {
-        val transformedOperands = Lists.<RexNode> newArrayList();
+        List<RexNode> transformedOperands = Lists.newArrayList();
         for (RexNode operand : extendedOperands) {
-            RexNode transformedOperand;
             if (!(operand instanceof RexLiteral)) {
                 return null;
             }
-            if (((RexLiteral) operand).getValue() instanceof NlsString) {
-                val transformed = transformRexLiteral(rexInputRef, (RexLiteral) operand);
-                transformedOperand = transformed == null ? operand : transformed;
-            } else {
-                transformedOperand = operand;
-            }
-
+            RexNode transformedOperand = transformRexLiteral(rexInputRef, (RexLiteral) operand);
             val operator = isIn ? SqlStdOperatorTable.EQUALS : SqlStdOperatorTable.NOT_EQUALS;
             transformedOperands.add(rexBuilder.makeCall(operator, rexInputRef, transformedOperand));
         }
@@ -186,18 +194,18 @@ public class FilterConditionExpander {
     }
 
     private RexLiteral transformRexLiteral(RexInputRef inputRef, RexLiteral operand2) {
-        val literalValue = operand2.getValue();
-        val literalValueInString = ((NlsString) literalValue).getValue();
-        val typeName = inputRef.getType().getSqlTypeName().getName();
+        DataType dataType = DataType.getType(inputRef.getType().getSqlTypeName().getName());
+        String value;
+        if (operand2.getValue() instanceof NlsString) {
+            value = RexLiteral.stringValue(operand2);
+        } else {
+            Comparable c = RexLiteral.value(operand2);
+            value = c == null ? null : c.toString();
+        }
         try {
-            if (typeName.equalsIgnoreCase(DATE)) {
-                return rexBuilder.makeDateLiteral(new DateString(literalValueInString));
-            } else if (typeName.equalsIgnoreCase(TIMESTAMP)) {
-                return rexBuilder.makeTimestampLiteral(new TimestampString(literalValueInString),
-                        inputRef.getType().getPrecision());
-            }
+            return (RexLiteral) RexUtils.transformValue2RexLiteral(rexBuilder, value, dataType);
         } catch (Exception ex) {
-            logger.warn("transform Date/Timestamp RexLiteral for filterRel failed", ex);
+            logger.warn("transform rexLiteral({}) failed: {}", RexLiteral.value(operand2), ex.getMessage());
         }
         return operand2;
     }
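
The new simplify() method above accepts two right-hand shapes: a bare literal, and CAST(<literal> AS <type>). A hedged sketch of that unwrapping rule against Calcite's RexNode API (the helper class and method names are illustrative, not from the patch):

    import org.apache.calcite.rex.RexCall;
    import org.apache.calcite.rex.RexLiteral;
    import org.apache.calcite.rex.RexNode;
    import org.apache.calcite.sql.SqlKind;

    class CastUnwrapSketch {
        // Returns the literal itself, the literal inside a CAST, or null when
        // the node is neither shape (in which case the filter is not simplified).
        static RexLiteral unwrapLiteral(RexNode node) {
            if (node instanceof RexLiteral) {
                return (RexLiteral) node;
            }
            if (node.isA(SqlKind.CAST)) {
                RexNode inner = ((RexCall) node).getOperands().get(0);
                if (inner instanceof RexLiteral) {
                    return (RexLiteral) inner;
                }
            }
            return null;
        }
    }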
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/SparderRexVisitor.scala b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/SparderRexVisitor.scala
index b8e34366b8..8361122437 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/SparderRexVisitor.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/SparderRexVisitor.scala
@@ -19,7 +19,6 @@
 package org.apache.kylin.query.runtime
 
 
-import java.math.BigDecimal
 import java.sql.Timestamp
 import java.time.ZoneId
 
@@ -324,23 +323,25 @@ class SparderRexVisitor(val inputFieldNames: Seq[String],
         if (Seq("MONTH", "YEAR", "QUARTER").contains(
           t.getIntervalQualifier.timeUnitRange.name)) {
           return Some(
-            MonthNum(k_lit(literal.getValue.asInstanceOf[BigDecimal].intValue)))
+            MonthNum(k_lit(RexLiteral.intValue(literal))))
         }
         if (literal.getType.getFamily
           .asInstanceOf[SqlTypeFamily] == SqlTypeFamily.INTERVAL_DAY_TIME) {
           return Some(
             SparderTypeUtil.toSparkTimestamp(
-              new java.math.BigDecimal(literal.getValue.toString).longValue()))
+              new java.math.BigDecimal(RexLiteral.value(literal).toString).longValue()))
         }
       }
 
       case literalSql: BasicSqlType => {
+        val literalStr = RexLiteral.value(literal).toString
         literalSql.getSqlTypeName match {
           case SqlTypeName.DATE =>
-            return Some(stringToTime(literal.toString))
+            return Some(stringToTime(literalStr))
           case SqlTypeName.TIMESTAMP =>
-            return Some(toJavaTimestamp(stringToTimestamp(UTF8String.fromString(literal.toString),
-              ZoneId.systemDefault()).head))
+            val string = UTF8String.fromString(literalStr)
+            val maybeLong = stringToTimestamp(string, ZoneId.systemDefault())
+            return Some(toJavaTimestamp(maybeLong.head))
           case _ =>
         }
       }
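
The Scala changes above replace direct getValue casts with Calcite's static RexLiteral helpers. A small Java sketch of the difference, assuming an interval literal whose stored value is a BigDecimal, as the old code assumed (names are illustrative):

    import java.math.BigDecimal;
    import org.apache.calcite.rex.RexLiteral;

    class RexLiteralAccessSketch {
        static int monthCount(RexLiteral literal) {
            // Old style: unchecked cast, assumes the stored value is a BigDecimal.
            int viaCast = ((BigDecimal) literal.getValue()).intValue();
            // New style: RexLiteral.intValue coerces the stored value for us.
            int viaHelper = RexLiteral.intValue(literal);
            assert viaCast == viaHelper;
            return viaHelper;
        }
    }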
diff --git a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala
index acce92b76a..d70c569a8c 100644
--- a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala
+++ b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala
@@ -229,6 +229,8 @@ object SparderTypeUtil extends Logging {
             b.floatValue()
           case SqlTypeName.SMALLINT =>
             b.shortValue()
+          case SqlTypeName.TINYINT =>
+            b.byteValue()
           case _ =>
             b
         }
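
The SparderTypeUtil change simply adds TINYINT to the BigDecimal narrowing switch. A trivial JDK-only sketch of the narrowing it relies on:

    import java.math.BigDecimal;

    class TinyIntNarrowingSketch {
        public static void main(String[] args) {
            BigDecimal b = new BigDecimal("42");
            byte asByte = b.byteValue();    // new TINYINT branch
            short asShort = b.shortValue(); // existing SMALLINT branch
            System.out.println(asByte + " " + asShort);
        }
    }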
