This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ab2ae034dc7 [SPARK-42597][SQL] Support unwrap date type to timestamp type
ab2ae034dc7 is described below

commit ab2ae034dc70adc597d3baf0d4b7347daa55caa8
Author: Yuming Wang <[email protected]>
AuthorDate: Tue Mar 14 08:00:40 2023 +0800

    [SPARK-42597][SQL] Support unwrap date type to timestamp type
    
    ### What changes were proposed in this pull request?
    
    This PR enhances `UnwrapCastInBinaryComparison` to support unwrapping date type to timestamp type.
    
    The rewrites for unwrapping date type to timestamp type are:
    ```
    GreaterThan(Cast(ts, DateType), date) -> GreaterThanOrEqual(ts, Cast(date + 1, TimestampType))
    GreaterThanOrEqual(Cast(ts, DateType), date) -> GreaterThanOrEqual(ts, Cast(date, TimestampType))
    Equality(Cast(ts, DateType), date) -> And(GreaterThanOrEqual(ts, Cast(date, TimestampType)), LessThan(ts, Cast(date + 1, TimestampType)))
    LessThan(Cast(ts, DateType), date) -> LessThan(ts, Cast(date, TimestampType))
    LessThanOrEqual(Cast(ts, DateType), date) -> LessThan(ts, Cast(date + 1, TimestampType))
    ```
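    
    For example, the equality rule turns an equality on the casted column into a half-open range on the timestamp column itself (an illustrative sketch; the table name `tbl` is hypothetical, and the timestamp literals assume the session time zone):
    ```sql
    -- Before: the cast on `ts` hides the column from pushdown and pruning.
    SELECT * FROM tbl WHERE CAST(ts AS DATE) = DATE '2023-01-01';
    -- After: an equivalent range predicate directly on `ts`.
    SELECT * FROM tbl
    WHERE ts >= TIMESTAMP '2023-01-01 00:00:00'
      AND ts <  TIMESTAMP '2023-01-02 00:00:00';
    ```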
    
    ### Why are the changes needed?
    
    Improve query performance.
    
    A common use case: we store cold data in HDFS by partition, store hot data in MySQL, and then union all the results. The filter in the MySQL branch cannot be pushed down, which hurts performance:
    ```sql
    CREATE TABLE t_cold(id bigint, start timestamp, dt date) using parquet PARTITIONED BY (dt);
    CREATE TABLE t_hot(id bigint, start timestamp) using org.apache.spark.sql.jdbc OPTIONS (`url` '***', `dbtable` 'db.t2', `user` 'spark', `password` '***');
    CREATE VIEW all_data AS SELECT * FROM t_cold UNION ALL SELECT *, to_date(start) FROM t_hot;
    SELECT * FROM all_data WHERE start BETWEEN '2023-02-06' AND '2023-02-07';
    ```
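    
    After this change, a date filter that reaches the hot branch as a cast on `start` can be rewritten into a plain range on `start` and pushed down to MySQL. A hedged sketch of the hot-branch predicate, assuming the filter arrives in the casted form handled by the rules above:
    ```sql
    -- Before the rule: the cast blocks JDBC pushdown.
    CAST(start AS DATE) >= DATE '2023-02-06' AND CAST(start AS DATE) <= DATE '2023-02-07'
    -- After the rule: a range directly on `start`, pushable to MySQL.
    start >= TIMESTAMP '2023-02-06 00:00:00' AND start < TIMESTAMP '2023-02-08 00:00:00'
    ```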
    
    Before this PR | After this PR
    -- | --
    <img src="https://user-images.githubusercontent.com/5399861/221576723-7fc45356-65db-48e2-8d40-88420c21c9f5.png" width="400" height="730"> | <img src="https://user-images.githubusercontent.com/5399861/221575848-5b975ed0-70ab-4527-acfe-796cc20e169b.png" width="400" height="730">
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit test.
    
    Closes #40190 from wangyum/SPARK-42597.
    
    Authored-by: Yuming Wang <[email protected]>
    Signed-off-by: Yuming Wang <[email protected]>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala   | 52 +++++++++++++----
 .../UnwrapCastInBinaryComparisonSuite.scala        | 68 ++++++++++++++++++----
 .../sql/UnwrapCastInComparisonEndToEndSuite.scala  | 24 ++++++++
 3 files changed, 124 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index f4a92760d22..d95bc694814 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
  *
  * Currently this only handles cases where:
  *   1). `fromType` (of `fromExp`) and `toType` are of numeric types (i.e., short, int, float,
- *     decimal, etc) or boolean type
+ *     decimal, etc), boolean type or datetime type
  *   2). `fromType` can be safely coerced to `toType` without precision loss (e.g., short to int,
  *     int to long, but not long to int, nor int to boolean)
  *
@@ -104,16 +104,15 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     case l: LogicalPlan =>
       l.transformExpressionsUpWithPruning(
         _.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
-        case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => unwrapCast(e)
+        case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => unwrapCast(e).getOrElse(e)
       }
   }
 
-  private def unwrapCast(exp: Expression): Expression = exp match {
+  private def unwrapCast(exp: Expression): Option[Expression] = exp match {
     // Not a canonical form. In this case we first canonicalize the expression by swapping the
     // literal and cast side, then process the result and swap the literal and cast again to
     // restore the original order.
-    case BinaryComparison(Literal(_, literalType), Cast(fromExp, toType, _, _))
-        if canImplicitlyCast(fromExp, toType, literalType) =>
+    case BinaryComparison(_: Literal, _: Cast) =>
       def swap(e: Expression): Expression = e match {
         case GreaterThan(left, right) => LessThan(right, left)
         case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left)
@@ -124,14 +123,19 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
         case _ => e
       }
 
-      swap(unwrapCast(swap(exp)))
+      unwrapCast(swap(exp)).map(swap)
 
     // In case both sides have numeric type, optimize the comparison by removing casts or
     // moving cast to the literal side.
     case be @ BinaryComparison(
       Cast(fromExp, toType: NumericType, _, _), Literal(value, literalType))
         if canImplicitlyCast(fromExp, toType, literalType) =>
-      simplifyNumericComparison(be, fromExp, toType, value)
+      Some(simplifyNumericComparison(be, fromExp, toType, value))
+
+    case be @ BinaryComparison(
+      Cast(fromExp, _, timeZoneId, evalMode), date @ Literal(value, DateType))
+        if AnyTimestampType.acceptsType(fromExp.dataType) && value != null =>
+      Some(unwrapDateToTimestamp(be, fromExp, date, timeZoneId, evalMode))
 
     // As the analyzer makes sure that the list of In is already of the same data type, then the
     // rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
@@ -151,7 +155,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
           val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
           In(fromExp, newList.toSeq)
       }
-      simplifyIn(fromExp, toType, list, buildIn).getOrElse(exp)
+      simplifyIn(fromExp, toType, list, buildIn)
 
     // The same with `In` expression, the analyzer makes sure that the hset of InSet is already of
     // the same data type, so simply check `fromExp.dataType` can implicitly cast to `toType` and
@@ -165,9 +169,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
         fromExp,
         toType,
         hset.map(v => Literal.create(v, toType)).toSeq,
-        buildInSet).getOrElse(exp)
+        buildInSet)
 
-    case _ => exp
+    case _ => None
   }
 
   /**
@@ -293,6 +297,34 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Move the cast to the literal side, because we can only get the minimum value of timestamp,
+   * so some BinaryComparison needs to be changed,
+   * such as CAST(ts AS date) > DATE '2023-01-01' ===> ts >= TIMESTAMP '2023-01-02 00:00:00'
+   */
+  private def unwrapDateToTimestamp(
+      exp: BinaryComparison,
+      fromExp: Expression,
+      date: Literal,
+      tz: Option[String],
+      evalMode: EvalMode.Value): Expression = {
+    val dateAddOne = DateAdd(date, Literal(1, IntegerType))
+    exp match {
+      case _: GreaterThan =>
+        GreaterThanOrEqual(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode))
+      case _: GreaterThanOrEqual =>
+        GreaterThanOrEqual(fromExp, Cast(date, fromExp.dataType, tz, evalMode))
+      case Equality(_, _) =>
+        And(GreaterThanOrEqual(fromExp, Cast(date, fromExp.dataType, tz, evalMode)),
+          LessThan(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode)))
+      case _: LessThan =>
+        LessThan(fromExp, Cast(date, fromExp.dataType, tz, evalMode))
+      case _: LessThanOrEqual =>
+        LessThan(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode))
+      case _ => exp
+    }
+  }
+
   private def simplifyIn[IN <: Expression](
       fromExp: Expression,
       toType: NumericType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 2e3b2708444..400f2f2c97b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import java.time.{LocalDate, LocalDateTime}
+
 import scala.collection.immutable.HashSet
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -39,11 +41,13 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
   }
 
   val testRelation: LocalRelation = LocalRelation($"a".short, $"b".float,
-    $"c".decimal(5, 2), $"d".boolean)
+    $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ)
   val f: BoundReference = $"a".short.canBeNull.at(0)
   val f2: BoundReference = $"b".float.canBeNull.at(1)
   val f3: BoundReference = $"c".decimal(5, 2).canBeNull.at(2)
   val f4: BoundReference = $"d".boolean.canBeNull.at(3)
+  val f5: BoundReference = $"e".timestamp.canBeNull.at(4)
+  val f6: BoundReference = $"f".timestampNTZ.canBeNull.at(5)
 
   test("unwrap casts when literal == max") {
     val v = Short.MaxValue
@@ -368,9 +372,53 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
     assertEquivalent(castInt(f4) < t, trueIfNotNull(f4))
   }
 
+  test("SPARK-42597: Support unwrap date to timestamp type") {
+    val dateLit = Literal.create(LocalDate.of(2023, 1, 1), DateType)
+    val dateAddOne = DateAdd(dateLit, Literal(1))
+    val nullLit = Literal.create(null, DateType)
+
+    assertEquivalent(
+      castDate(f5) > dateLit || castDate(f6) > dateLit,
+      f5 >= castTimestamp(dateAddOne) || f6 >= castTimestampNTZ(dateAddOne))
+    assertEquivalent(
+      castDate(f5) >= dateLit || castDate(f6) >= dateLit,
+      f5 >= castTimestamp(dateLit) || f6 >= castTimestampNTZ(dateLit))
+    assertEquivalent(
+      castDate(f5) < dateLit || castDate(f6) < dateLit,
+      f5 < castTimestamp(dateLit) || f6 < castTimestampNTZ(dateLit))
+    assertEquivalent(
+      castDate(f5) <= dateLit || castDate(f6) <= dateLit,
+      f5 < castTimestamp(dateAddOne) || f6 < castTimestampNTZ(dateAddOne))
+    assertEquivalent(
+      castDate(f5) === dateLit || castDate(f6) === dateLit,
+      (f5 >= castTimestamp(dateLit) && f5 < castTimestamp(dateAddOne)) ||
+        (f6 >= castTimestampNTZ(dateLit) && f6 < castTimestampNTZ(dateAddOne)))
+    assertEquivalent(
+      castDate(f5) <=> dateLit || castDate(f6) === dateLit,
+      (f5 >= castTimestamp(dateLit) && f5 < castTimestamp(dateAddOne)) ||
+        (f6 >= castTimestampNTZ(dateLit) && f6 < castTimestampNTZ(dateAddOne)))
+    assertEquivalent(
+      dateLit < castDate(f5) || dateLit < castDate(f6),
+      castTimestamp(dateAddOne) <= f5 || castTimestampNTZ(dateAddOne) <= f6)
+
+    // Null date literal should be handled by NullPropagation
+    assertEquivalent(castDate(f5) > nullLit || castDate(f6) > nullLit,
+      Literal.create(null, BooleanType) || Literal.create(null, BooleanType))
+  }
+
+  private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
+  private val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
+  private val ts3 = LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999)
+  private val ts4 = LocalDateTime.of(1, 1, 1, 0, 0, 0, 0)
+
   private def castInt(e: Expression): Expression = Cast(e, IntegerType)
   private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
   private def castDecimal2(e: Expression): Expression = Cast(e, DecimalType(10, 4))
+  private def castDate(e: Expression): Expression = Cast(e, DateType)
+  private def castTimestamp(e: Expression): Expression =
+    Cast(e, TimestampType, Some(conf.sessionLocalTimeZone))
+  private def castTimestampNTZ(e: Expression): Expression =
+    Cast(e, TimestampNTZType, Some(conf.sessionLocalTimeZone))
 
   private def decimal(v: Decimal): Decimal = Decimal(v.toJavaBigDecimal, 5, 2)
   private def decimal2(v: BigDecimal): Decimal = Decimal(v, 10, 4)
@@ -383,16 +431,16 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
 
     if (evaluate) {
       Seq(
-        (100.toShort, 3.14.toFloat, decimal2(100), true),
-        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false),
-        (null, Float.NaN, decimal2(12345.6789), null),
-        (null, null, null, null),
-        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true),
-        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false),
-        (0.toShort, Float.MaxValue, decimal2(0), null),
-        (0.toShort, Float.MinValue, decimal2(0.01), null)
+        (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1),
+        (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2),
+        (null, Float.NaN, decimal2(12345.6789), null, null, null),
+        (null, null, null, null, null, null),
+        (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3),
+        (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4),
+        (0.toShort, Float.MaxValue, decimal2(0), null, null, null),
+        (0.toShort, Float.MinValue, decimal2(0.01), null, null, null)
       ).foreach(v => {
-        val row = create_row(v._1, v._2, v._3, v._4)
+        val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6)
         checkEvaluation(e1, e2.eval(row), row)
       })
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
index 1d7af84ef60..468915aa493 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.time.LocalDateTime
+
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.Decimal
@@ -240,5 +242,27 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess
     }
   }
 
+  test("SPARK-42597: Support unwrap date type to timestamp type") {
+    val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
+    val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
+    val ts3 = LocalDateTime.of(2023, 1, 2, 23, 59, 59, 8000)
+
+    withTable(t) {
+      Seq(ts1, ts2, ts3).toDF("ts").write.saveAsTable(t)
+      val df = spark.table(t)
+
+      checkAnswer(
+        df.where("cast(ts as date) > date'2023-01-01'"), Seq(ts3).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) >= date'2023-01-01'"), Seq(ts1, ts2, 
ts3).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) < date'2023-01-02'"), Seq(ts1, 
ts2).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) <= date'2023-01-02'"), Seq(ts1, ts2, 
ts3).map(Row(_)))
+      checkAnswer(
+        df.where("cast(ts as date) = date'2023-01-01'"), Seq(ts1, 
ts2).map(Row(_)))
+    }
+  }
+
   private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2)
 }

