This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f555e9983f6d [SPARK-46069][SQL][FOLLOWUP] Simplify the algorithm and 
add comments
f555e9983f6d is described below

commit f555e9983f6dcea90066c3f4678fb11e11c6949e
Author: Wenchen Fan <[email protected]>
AuthorDate: Wed Dec 20 16:10:02 2023 +0800

    [SPARK-46069][SQL][FOLLOWUP] Simplify the algorithm and add comments
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/43982, to 
simplify the algorithm without the "add one day" operation. This makes the 
algorithm easier to document.
    
    ### Why are the changes needed?
    
    Code simplification.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    updated tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #44403 from cloud-fan/minor.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../optimizer/UnwrapCastInBinaryComparison.scala   | 50 ++++++++++++++++------
 .../UnwrapCastInBinaryComparisonSuite.scala        | 48 ++++++++++-----------
 2 files changed, 59 insertions(+), 39 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index dd516afeb58c..19af9f5a2b55 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -341,36 +341,58 @@ object UnwrapCastInBinaryComparison extends 
Rule[LogicalPlan] {
       ts: Literal,
       tz: Option[String],
       evalMode: EvalMode.Value): Expression = {
-    val floorDate = Cast(ts, fromExp.dataType, tz, evalMode)
-    val dateAddOne = DateAdd(floorDate, Literal(1, IntegerType))
-    val isStartOfDay =
-      EqualTo(ts, Cast(floorDate, ts.dataType, tz, 
evalMode)).eval(EmptyRow).asInstanceOf[Boolean]
+    assert(fromExp.dataType == DateType)
+    val floorDate = Literal(Cast(ts, DateType, tz, evalMode).eval(), DateType)
+    val timePartsAllZero =
+      EqualTo(ts, Cast(floorDate, ts.dataType, tz, 
evalMode)).eval().asInstanceOf[Boolean]
 
     exp match {
       case _: GreaterThan =>
+        // "CAST(date AS TIMESTAMP) > timestamp"  ==>  "date > floor_date", no 
matter the
+        // timestamp has non-zero time part or not.
         GreaterThan(fromExp, floorDate)
+      case _: LessThanOrEqual =>
+        // "CAST(date AS TIMESTAMP) <= timestamp"  ==>  "date <= floor_date", 
no matter the
+        // timestamp has non-zero time part or not.
+        LessThanOrEqual(fromExp, floorDate)
       case _: GreaterThanOrEqual =>
-        if (isStartOfDay) {
+        if (!timePartsAllZero) {
+          // "CAST(date AS TIMESTAMP) >= timestamp"  ==>  "date > floor_date", 
if the timestamp has
+          // non-zero time part.
+          GreaterThan(fromExp, floorDate)
+        } else {
+          // If the timestamp's time parts are all zero, the date can also be 
the floor_date.
           GreaterThanOrEqual(fromExp, floorDate)
+        }
+      case _: LessThan =>
+        if (!timePartsAllZero) {
+          // "CAST(date AS TIMESTAMP) < timestamp"  ==>  "date <= floor_date", 
if the timestamp has
+          // non-zero time part.
+          LessThanOrEqual(fromExp, floorDate)
         } else {
-          GreaterThanOrEqual(fromExp, dateAddOne)
+          // If the timestamp's time parts are all zero, the date can not be 
the floor_date.
+          LessThan(fromExp, floorDate)
         }
       case _: EqualTo =>
-        if (isStartOfDay) {
+        if (timePartsAllZero) {
+          // "CAST(date AS TIMESTAMP) = timestamp"  ==>  "date = floor_date", 
if the timestamp's
+          // time parts are all zero
           EqualTo(fromExp, floorDate)
         } else {
+          // if the timestamp has non-zero time part, then we always get false 
unless the date is
+          // null, in which case the result is also null.
           falseIfNotNull(fromExp)
         }
       case _: EqualNullSafe =>
-        if (isStartOfDay) EqualNullSafe(fromExp, floorDate) else FalseLiteral
-      case _: LessThan =>
-        if (isStartOfDay) {
-          LessThan(fromExp, floorDate)
+        if (timePartsAllZero) {
+          // "CAST(date AS TIMESTAMP) <=> timestamp"  ==>  "date <=> 
floor_date", if the timestamp's
+          // time parts are all zero
+          EqualNullSafe(fromExp, floorDate)
         } else {
-          LessThan(fromExp, dateAddOne)
+          // if the timestamp has non-zero time part, then we always get false 
because this is
+          // null-safe equal comparison.
+          FalseLiteral
         }
-      case _: LessThanOrEqual =>
-        LessThanOrEqual(fromExp, floorDate)
       case _ => exp
     }
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 646f32729607..5ccff73e5238 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -407,43 +407,41 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest 
with ExpressionEvalHelp
   }
 
   test("SPARK-46069: Support unwrap timestamp type to date type") {
-    def doTest(
-        tsLit: Literal,
-        isStartOfDay: Boolean,
-        castTimestampFunc: Expression => Expression): Unit = {
-      val floorDate = Cast(tsLit, DateType, Some(conf.sessionLocalTimeZone))
-      val dateAddOne = DateAdd(floorDate, Literal(1, IntegerType))
-      assertEquivalent(castTimestampFunc(f7) > tsLit, f7 > floorDate)
-      if (isStartOfDay) {
-        assertEquivalent(castTimestampFunc(f7) >= tsLit, f7 >= floorDate)
-        assertEquivalent(castTimestampFunc(f7) === tsLit, f7 === floorDate)
-        assertEquivalent(castTimestampFunc(f7) <=> tsLit, f7 <=> floorDate)
-        assertEquivalent(castTimestampFunc(f7) < tsLit, f7 < floorDate)
+    def doTest(tsLit: Literal, timePartsAllZero: Boolean): Unit = {
+      val tz = Some(conf.sessionLocalTimeZone)
+      val floorDate = Literal(Cast(tsLit, DateType, tz).eval(), DateType)
+      val dateToTsCast = Cast(f7, tsLit.dataType, tz)
+
+      assertEquivalent(dateToTsCast > tsLit, f7 > floorDate)
+      assertEquivalent(dateToTsCast <= tsLit, f7 <= floorDate)
+      if (timePartsAllZero) {
+        assertEquivalent(dateToTsCast >= tsLit, f7 >= floorDate)
+        assertEquivalent(dateToTsCast < tsLit, f7 < floorDate)
+        assertEquivalent(dateToTsCast === tsLit, f7 === floorDate)
+        assertEquivalent(dateToTsCast <=> tsLit, f7 <=> floorDate)
       } else {
-        assertEquivalent(castTimestampFunc(f7) >= tsLit, f7 >= dateAddOne)
-        assertEquivalent(castTimestampFunc(f7) === tsLit, f7.isNull && 
Literal(null, BooleanType))
-        assertEquivalent(castTimestampFunc(f7) <=> tsLit, FalseLiteral)
-        assertEquivalent(castTimestampFunc(f7) < tsLit, f7 < dateAddOne)
+        assertEquivalent(dateToTsCast >= tsLit, f7 > floorDate)
+        assertEquivalent(dateToTsCast < tsLit, f7 <= floorDate)
+        assertEquivalent(dateToTsCast === tsLit, f7.isNull && Literal(null, 
BooleanType))
+        assertEquivalent(dateToTsCast <=> tsLit, FalseLiteral)
       }
-      assertEquivalent(castTimestampFunc(f7) <= tsLit, f7 <= floorDate)
     }
 
-    // Test isStartOfDay is true cases
+    // Test timestamp with all its time parts as 0.
     val micros = SparkDateTimeUtils.daysToMicros(19704, 
ZoneId.of(conf.sessionLocalTimeZone))
     val instant = java.time.Instant.ofEpochSecond(micros / 1000000)
     val tsLit = Literal.create(instant, TimestampType)
-    doTest(tsLit, isStartOfDay = true, castTimestamp)
+    doTest(tsLit, timePartsAllZero = true)
 
     val tsNTZ = LocalDateTime.of(2023, 12, 13, 0, 0, 0, 0)
     val tsNTZLit = Literal.create(tsNTZ, TimestampNTZType)
-    doTest(tsNTZLit, isStartOfDay = true, castTimestampNTZ)
+    doTest(tsNTZLit, timePartsAllZero = true)
 
-    // Test isStartOfDay is false cases
+    // Test timestamp with non-zero time parts.
     val tsLit2 = Literal.create(instant.plusSeconds(30), TimestampType)
-    val tsNTZ2 = LocalDateTime.of(2023, 12, 13, 0, 0, 30, 0)
-    doTest(tsLit2, isStartOfDay = false, castTimestamp)
-    val tsNTZLit2 = Literal.create(tsNTZ2, TimestampNTZType)
-    doTest(tsNTZLit2, isStartOfDay = false, castTimestampNTZ)
+    doTest(tsLit2, timePartsAllZero = false)
+    val tsNTZLit2 = Literal.create(tsNTZ.withSecond(30), TimestampNTZType)
+    doTest(tsNTZLit2, timePartsAllZero = false)
   }
 
   private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to