This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f555e9983f6d [SPARK-46069][SQL][FOLLOWUP] Simplify the algorithm and
add comments
f555e9983f6d is described below
commit f555e9983f6dcea90066c3f4678fb11e11c6949e
Author: Wenchen Fan <[email protected]>
AuthorDate: Wed Dec 20 16:10:02 2023 +0800
[SPARK-46069][SQL][FOLLOWUP] Simplify the algorithm and add comments
### What changes were proposed in this pull request?
This is a follow-up to https://github.com/apache/spark/pull/43982 that
simplifies the algorithm by removing the "add one day" operation. This makes the
algorithm easier to document.
### Why are the changes needed?
Code simplification.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
updated tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44403 from cloud-fan/minor.
Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../optimizer/UnwrapCastInBinaryComparison.scala | 50 ++++++++++++++++------
.../UnwrapCastInBinaryComparisonSuite.scala | 48 ++++++++++-----------
2 files changed, 59 insertions(+), 39 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index dd516afeb58c..19af9f5a2b55 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -341,36 +341,58 @@ object UnwrapCastInBinaryComparison extends
Rule[LogicalPlan] {
ts: Literal,
tz: Option[String],
evalMode: EvalMode.Value): Expression = {
- val floorDate = Cast(ts, fromExp.dataType, tz, evalMode)
- val dateAddOne = DateAdd(floorDate, Literal(1, IntegerType))
- val isStartOfDay =
- EqualTo(ts, Cast(floorDate, ts.dataType, tz,
evalMode)).eval(EmptyRow).asInstanceOf[Boolean]
+ assert(fromExp.dataType == DateType)
+ val floorDate = Literal(Cast(ts, DateType, tz, evalMode).eval(), DateType)
+ val timePartsAllZero =
+ EqualTo(ts, Cast(floorDate, ts.dataType, tz,
evalMode)).eval().asInstanceOf[Boolean]
exp match {
case _: GreaterThan =>
+ // "CAST(date AS TIMESTAMP) > timestamp" ==> "date > floor_date", no
matter the
+ // timestamp has non-zero time part or not.
GreaterThan(fromExp, floorDate)
+ case _: LessThanOrEqual =>
+ // "CAST(date AS TIMESTAMP) <= timestamp" ==> "date <= floor_date",
no matter the
+ // timestamp has non-zero time part or not.
+ LessThanOrEqual(fromExp, floorDate)
case _: GreaterThanOrEqual =>
- if (isStartOfDay) {
+ if (!timePartsAllZero) {
+ // "CAST(date AS TIMESTAMP) >= timestamp" ==> "date > floor_date",
if the timestamp has
+ // non-zero time part.
+ GreaterThan(fromExp, floorDate)
+ } else {
+ // If the timestamp's time parts are all zero, the date can also be
the floor_date.
GreaterThanOrEqual(fromExp, floorDate)
+ }
+ case _: LessThan =>
+ if (!timePartsAllZero) {
+ // "CAST(date AS TIMESTAMP) < timestamp" ==> "date <= floor_date",
if the timestamp has
+ // non-zero time part.
+ LessThanOrEqual(fromExp, floorDate)
} else {
- GreaterThanOrEqual(fromExp, dateAddOne)
+ // If the timestamp's time parts are all zero, the date can not be
the floor_date.
+ LessThan(fromExp, floorDate)
}
case _: EqualTo =>
- if (isStartOfDay) {
+ if (timePartsAllZero) {
+ // "CAST(date AS TIMESTAMP) = timestamp" ==> "date = floor_date",
if the timestamp's
+ // time parts are all zero
EqualTo(fromExp, floorDate)
} else {
+ // if the timestamp has non-zero time part, then we always get false
unless the date is
+ // null, in which case the result is also null.
falseIfNotNull(fromExp)
}
case _: EqualNullSafe =>
- if (isStartOfDay) EqualNullSafe(fromExp, floorDate) else FalseLiteral
- case _: LessThan =>
- if (isStartOfDay) {
- LessThan(fromExp, floorDate)
+ if (timePartsAllZero) {
+ // "CAST(date AS TIMESTAMP) <=> timestamp" ==> "date <=>
floor_date", if the timestamp's
+ // time parts are all zero
+ EqualNullSafe(fromExp, floorDate)
} else {
- LessThan(fromExp, dateAddOne)
+ // if the timestamp has non-zero time part, then we always get false
because this is
+ // null-safe equal comparison.
+ FalseLiteral
}
- case _: LessThanOrEqual =>
- LessThanOrEqual(fromExp, floorDate)
case _ => exp
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 646f32729607..5ccff73e5238 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -407,43 +407,41 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest
with ExpressionEvalHelp
}
test("SPARK-46069: Support unwrap timestamp type to date type") {
- def doTest(
- tsLit: Literal,
- isStartOfDay: Boolean,
- castTimestampFunc: Expression => Expression): Unit = {
- val floorDate = Cast(tsLit, DateType, Some(conf.sessionLocalTimeZone))
- val dateAddOne = DateAdd(floorDate, Literal(1, IntegerType))
- assertEquivalent(castTimestampFunc(f7) > tsLit, f7 > floorDate)
- if (isStartOfDay) {
- assertEquivalent(castTimestampFunc(f7) >= tsLit, f7 >= floorDate)
- assertEquivalent(castTimestampFunc(f7) === tsLit, f7 === floorDate)
- assertEquivalent(castTimestampFunc(f7) <=> tsLit, f7 <=> floorDate)
- assertEquivalent(castTimestampFunc(f7) < tsLit, f7 < floorDate)
+ def doTest(tsLit: Literal, timePartsAllZero: Boolean): Unit = {
+ val tz = Some(conf.sessionLocalTimeZone)
+ val floorDate = Literal(Cast(tsLit, DateType, tz).eval(), DateType)
+ val dateToTsCast = Cast(f7, tsLit.dataType, tz)
+
+ assertEquivalent(dateToTsCast > tsLit, f7 > floorDate)
+ assertEquivalent(dateToTsCast <= tsLit, f7 <= floorDate)
+ if (timePartsAllZero) {
+ assertEquivalent(dateToTsCast >= tsLit, f7 >= floorDate)
+ assertEquivalent(dateToTsCast < tsLit, f7 < floorDate)
+ assertEquivalent(dateToTsCast === tsLit, f7 === floorDate)
+ assertEquivalent(dateToTsCast <=> tsLit, f7 <=> floorDate)
} else {
- assertEquivalent(castTimestampFunc(f7) >= tsLit, f7 >= dateAddOne)
- assertEquivalent(castTimestampFunc(f7) === tsLit, f7.isNull &&
Literal(null, BooleanType))
- assertEquivalent(castTimestampFunc(f7) <=> tsLit, FalseLiteral)
- assertEquivalent(castTimestampFunc(f7) < tsLit, f7 < dateAddOne)
+ assertEquivalent(dateToTsCast >= tsLit, f7 > floorDate)
+ assertEquivalent(dateToTsCast < tsLit, f7 <= floorDate)
+ assertEquivalent(dateToTsCast === tsLit, f7.isNull && Literal(null,
BooleanType))
+ assertEquivalent(dateToTsCast <=> tsLit, FalseLiteral)
}
- assertEquivalent(castTimestampFunc(f7) <= tsLit, f7 <= floorDate)
}
- // Test isStartOfDay is true cases
+ // Test timestamp with all its time parts as 0.
val micros = SparkDateTimeUtils.daysToMicros(19704,
ZoneId.of(conf.sessionLocalTimeZone))
val instant = java.time.Instant.ofEpochSecond(micros / 1000000)
val tsLit = Literal.create(instant, TimestampType)
- doTest(tsLit, isStartOfDay = true, castTimestamp)
+ doTest(tsLit, timePartsAllZero = true)
val tsNTZ = LocalDateTime.of(2023, 12, 13, 0, 0, 0, 0)
val tsNTZLit = Literal.create(tsNTZ, TimestampNTZType)
- doTest(tsNTZLit, isStartOfDay = true, castTimestampNTZ)
+ doTest(tsNTZLit, timePartsAllZero = true)
- // Test isStartOfDay is false cases
+ // Test timestamp with non-zero time parts.
val tsLit2 = Literal.create(instant.plusSeconds(30), TimestampType)
- val tsNTZ2 = LocalDateTime.of(2023, 12, 13, 0, 0, 30, 0)
- doTest(tsLit2, isStartOfDay = false, castTimestamp)
- val tsNTZLit2 = Literal.create(tsNTZ2, TimestampNTZType)
- doTest(tsNTZLit2, isStartOfDay = false, castTimestampNTZ)
+ doTest(tsLit2, timePartsAllZero = false)
+ val tsNTZLit2 = Literal.create(tsNTZ.withSecond(30), TimestampNTZType)
+ doTest(tsNTZLit2, timePartsAllZero = false)
}
private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]