This is an automated email from the ASF dual-hosted git repository.
yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ab2ae034dc7 [SPARK-42597][SQL] Support unwrap date type to timestamp type
ab2ae034dc7 is described below
commit ab2ae034dc70adc597d3baf0d4b7347daa55caa8
Author: Yuming Wang <[email protected]>
AuthorDate: Tue Mar 14 08:00:40 2023 +0800
[SPARK-42597][SQL] Support unwrap date type to timestamp type
### What changes were proposed in this pull request?
This PR enhances `UnwrapCastInBinaryComparison` to support unwrapping date type to timestamp type.
The rules for unwrapping date type to timestamp type are:
```
GreaterThan(Cast(ts, DateType), date) -> GreaterThanOrEqual(ts, Cast(date + 1, TimestampType))
GreaterThanOrEqual(Cast(ts, DateType), date) -> GreaterThanOrEqual(ts, Cast(date, TimestampType))
Equality(Cast(ts, DateType), date) -> And(GreaterThanOrEqual(ts, Cast(date, TimestampType)), LessThan(ts, Cast(date + 1, TimestampType)))
LessThan(Cast(ts, DateType), date) -> LessThan(ts, Cast(date, TimestampType))
LessThanOrEqual(Cast(ts, DateType), date) -> LessThan(ts, Cast(date + 1, TimestampType))
```
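For example, the `Equality` rule turns an equality filter on the casted column into a half-open timestamp range. A sketch of the effect in SQL, assuming a table `t` with a timestamp column `ts` (the optimizer rewrites expressions in the plan, not the SQL text):
```sql
-- Before: the comparison is on CAST(ts AS DATE), which blocks filter pushdown
SELECT * FROM t WHERE CAST(ts AS DATE) = DATE '2023-01-01';

-- After the rewrite: a half-open range directly on the timestamp column
SELECT * FROM t
WHERE ts >= TIMESTAMP '2023-01-01 00:00:00'
  AND ts <  TIMESTAMP '2023-01-02 00:00:00';
```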
### Why are the changes needed?
Improve query performance.
A common use case: we store cold data in HDFS by partition and hot data in MySQL, then union the results. The filter in the MySQL branch cannot be pushed down, which hurts performance:
```sql
CREATE TABLE t_cold(id bigint, start timestamp, dt date) using parquet PARTITIONED BY (dt);
CREATE TABLE t_hot(id bigint, start timestamp) using org.apache.spark.sql.jdbc OPTIONS (`url` '***', `dbtable` 'db.t2', `user` 'spark', `password` '***');
CREATE VIEW all_data AS SELECT * FROM t_cold UNION ALL SELECT *, to_date(start) FROM t_hot;
SELECT * FROM all_data WHERE start BETWEEN '2023-02-06' AND '2023-02-07';
```
Before this PR | After this PR
-- | --
<img src="https://user-images.githubusercontent.com/5399861/221576723-7fc45356-65db-48e2-8d40-88420c21c9f5.png" width="400" height="730"> | <img src="https://user-images.githubusercontent.com/5399861/221575848-5b975ed0-70ab-4527-acfe-796cc20e169b.png" width="400" height="730">
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit test.
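Outside the unit tests, the rewrite can be observed in the optimized plan. A hypothetical spark-sql check (`t` and `ts` are placeholder names, not part of this patch):
```sql
EXPLAIN EXTENDED SELECT * FROM t WHERE CAST(ts AS DATE) > DATE '2023-01-01';
-- The optimized logical plan should show a predicate like
--   ts >= TIMESTAMP '2023-01-02 00:00:00'
-- rather than a comparison on cast(ts AS date).
```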
Closes #40190 from wangyum/SPARK-42597.
Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Yuming Wang <[email protected]>
---
.../optimizer/UnwrapCastInBinaryComparison.scala | 52 +++++++++++++----
.../UnwrapCastInBinaryComparisonSuite.scala | 68 ++++++++++++++++++----
.../sql/UnwrapCastInComparisonEndToEndSuite.scala | 24 ++++++++
3 files changed, 124 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index f4a92760d22..d95bc694814 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
*
* Currently this only handles cases where:
* 1). `fromType` (of `fromExp`) and `toType` are of numeric types (i.e., short, int, float,
- * decimal, etc) or boolean type
+ * decimal, etc), boolean type or datetime type
* 2). `fromType` can be safely coerced to `toType` without precision loss (e.g., short to int,
* int to long, but not long to int, nor int to boolean)
*
@@ -104,16 +104,15 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
case l: LogicalPlan =>
l.transformExpressionsUpWithPruning(
_.containsAnyPattern(BINARY_COMPARISON, IN, INSET), ruleId) {
- case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => unwrapCast(e)
+ case e @ (BinaryComparison(_, _) | In(_, _) | InSet(_, _)) => unwrapCast(e).getOrElse(e)
}
}
- private def unwrapCast(exp: Expression): Expression = exp match {
+ private def unwrapCast(exp: Expression): Option[Expression] = exp match {
// Not a canonical form. In this case we first canonicalize the expression by swapping the
// literal and cast side, then process the result and swap the literal and cast again to
// restore the original order.
- case BinaryComparison(Literal(_, literalType), Cast(fromExp, toType, _, _))
- if canImplicitlyCast(fromExp, toType, literalType) =>
+ case BinaryComparison(_: Literal, _: Cast) =>
def swap(e: Expression): Expression = e match {
case GreaterThan(left, right) => LessThan(right, left)
case GreaterThanOrEqual(left, right) => LessThanOrEqual(right, left)
@@ -124,14 +123,19 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
case _ => e
}
- swap(unwrapCast(swap(exp)))
+ unwrapCast(swap(exp)).map(swap)
// In case both sides have numeric type, optimize the comparison by removing casts or
// moving cast to the literal side.
case be @ BinaryComparison(
Cast(fromExp, toType: NumericType, _, _), Literal(value, literalType))
if canImplicitlyCast(fromExp, toType, literalType) =>
- simplifyNumericComparison(be, fromExp, toType, value)
+ Some(simplifyNumericComparison(be, fromExp, toType, value))
+
+ case be @ BinaryComparison(
+ Cast(fromExp, _, timeZoneId, evalMode), date @ Literal(value, DateType))
+ if AnyTimestampType.acceptsType(fromExp.dataType) && value != null =>
+ Some(unwrapDateToTimestamp(be, fromExp, date, timeZoneId, evalMode))
// As the analyzer makes sure that the list of In is already of the same data type, then the
// rule can simply check the first literal in `in.list` can implicitly cast to `toType` or not,
@@ -151,7 +155,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
val newList = nullList.map(lit => Cast(lit, fromExp.dataType)) ++ canCastList
In(fromExp, newList.toSeq)
}
- simplifyIn(fromExp, toType, list, buildIn).getOrElse(exp)
+ simplifyIn(fromExp, toType, list, buildIn)
// The same with `In` expression, the analyzer makes sure that the hset of InSet is already of
// the same data type, so simply check `fromExp.dataType` can implicitly cast to `toType` and
@@ -165,9 +169,9 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
fromExp,
toType,
hset.map(v => Literal.create(v, toType)).toSeq,
- buildInSet).getOrElse(exp)
+ buildInSet)
- case _ => exp
+ case _ => None
}
/**
@@ -293,6 +297,34 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
}
}
+ /**
+ * Move the cast to the literal side, because we can only get the minimum value of timestamp,
+ * so some BinaryComparison needs to be changed,
+ * such as CAST(ts AS date) > DATE '2023-01-01' ===> ts >= TIMESTAMP '2023-01-02 00:00:00'
+ */
+ private def unwrapDateToTimestamp(
+ exp: BinaryComparison,
+ fromExp: Expression,
+ date: Literal,
+ tz: Option[String],
+ evalMode: EvalMode.Value): Expression = {
+ val dateAddOne = DateAdd(date, Literal(1, IntegerType))
+ exp match {
+ case _: GreaterThan =>
GreaterThanOrEqual(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode))
+ case _: GreaterThanOrEqual =>
+ GreaterThanOrEqual(fromExp, Cast(date, fromExp.dataType, tz, evalMode))
+ case Equality(_, _) =>
And(GreaterThanOrEqual(fromExp, Cast(date, fromExp.dataType, tz, evalMode)),
+ LessThan(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode)))
+ case _: LessThan =>
+ LessThan(fromExp, Cast(date, fromExp.dataType, tz, evalMode))
+ case _: LessThanOrEqual =>
+ LessThan(fromExp, Cast(dateAddOne, fromExp.dataType, tz, evalMode))
+ case _ => exp
+ }
+ }
+
private def simplifyIn[IN <: Expression](
fromExp: Expression,
toType: NumericType,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
index 2e3b2708444..400f2f2c97b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.optimizer
+import java.time.{LocalDate, LocalDateTime}
+
import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -39,11 +41,13 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
}
val testRelation: LocalRelation = LocalRelation($"a".short, $"b".float,
- $"c".decimal(5, 2), $"d".boolean)
+ $"c".decimal(5, 2), $"d".boolean, $"e".timestamp, $"f".timestampNTZ)
val f: BoundReference = $"a".short.canBeNull.at(0)
val f2: BoundReference = $"b".float.canBeNull.at(1)
val f3: BoundReference = $"c".decimal(5, 2).canBeNull.at(2)
val f4: BoundReference = $"d".boolean.canBeNull.at(3)
+ val f5: BoundReference = $"e".timestamp.canBeNull.at(4)
+ val f6: BoundReference = $"f".timestampNTZ.canBeNull.at(5)
test("unwrap casts when literal == max") {
val v = Short.MaxValue
@@ -368,9 +372,53 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
assertEquivalent(castInt(f4) < t, trueIfNotNull(f4))
}
+ test("SPARK-42597: Support unwrap date to timestamp type") {
+ val dateLit = Literal.create(LocalDate.of(2023, 1, 1), DateType)
+ val dateAddOne = DateAdd(dateLit, Literal(1))
+ val nullLit = Literal.create(null, DateType)
+
+ assertEquivalent(
+ castDate(f5) > dateLit || castDate(f6) > dateLit,
+ f5 >= castTimestamp(dateAddOne) || f6 >= castTimestampNTZ(dateAddOne))
+ assertEquivalent(
+ castDate(f5) >= dateLit || castDate(f6) >= dateLit,
+ f5 >= castTimestamp(dateLit) || f6 >= castTimestampNTZ(dateLit))
+ assertEquivalent(
+ castDate(f5) < dateLit || castDate(f6) < dateLit,
+ f5 < castTimestamp(dateLit) || f6 < castTimestampNTZ(dateLit))
+ assertEquivalent(
+ castDate(f5) <= dateLit || castDate(f6) <= dateLit,
+ f5 < castTimestamp(dateAddOne) || f6 < castTimestampNTZ(dateAddOne))
+ assertEquivalent(
+ castDate(f5) === dateLit || castDate(f6) === dateLit,
+ (f5 >= castTimestamp(dateLit) && f5 < castTimestamp(dateAddOne)) ||
+ (f6 >= castTimestampNTZ(dateLit) && f6 < castTimestampNTZ(dateAddOne)))
+ assertEquivalent(
+ castDate(f5) <=> dateLit || castDate(f6) === dateLit,
+ (f5 >= castTimestamp(dateLit) && f5 < castTimestamp(dateAddOne)) ||
+ (f6 >= castTimestampNTZ(dateLit) && f6 < castTimestampNTZ(dateAddOne)))
+ assertEquivalent(
+ dateLit < castDate(f5) || dateLit < castDate(f6),
+ castTimestamp(dateAddOne) <= f5 || castTimestampNTZ(dateAddOne) <= f6)
+
+ // Null date literal should be handled by NullPropagation
+ assertEquivalent(castDate(f5) > nullLit || castDate(f6) > nullLit,
+ Literal.create(null, BooleanType) || Literal.create(null, BooleanType))
+ }
+
+ private val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
+ private val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
+ private val ts3 = LocalDateTime.of(9999, 12, 31, 23, 59, 59, 999999999)
+ private val ts4 = LocalDateTime.of(1, 1, 1, 0, 0, 0, 0)
+
private def castInt(e: Expression): Expression = Cast(e, IntegerType)
private def castDouble(e: Expression): Expression = Cast(e, DoubleType)
private def castDecimal2(e: Expression): Expression = Cast(e, DecimalType(10, 4))
+ private def castDate(e: Expression): Expression = Cast(e, DateType)
+ private def castTimestamp(e: Expression): Expression =
+ Cast(e, TimestampType, Some(conf.sessionLocalTimeZone))
+ private def castTimestampNTZ(e: Expression): Expression =
+ Cast(e, TimestampNTZType, Some(conf.sessionLocalTimeZone))
private def decimal(v: Decimal): Decimal = Decimal(v.toJavaBigDecimal, 5, 2)
private def decimal2(v: BigDecimal): Decimal = Decimal(v, 10, 4)
@@ -383,16 +431,16 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp
if (evaluate) {
Seq(
- (100.toShort, 3.14.toFloat, decimal2(100), true),
- (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false),
- (null, Float.NaN, decimal2(12345.6789), null),
- (null, null, null, null),
- (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true),
- (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false),
- (0.toShort, Float.MaxValue, decimal2(0), null),
- (0.toShort, Float.MinValue, decimal2(0.01), null)
+ (100.toShort, 3.14.toFloat, decimal2(100), true, ts1, ts1),
+ (-300.toShort, 3.1415927.toFloat, decimal2(-3000.50), false, ts2, ts2),
+ (null, Float.NaN, decimal2(12345.6789), null, null, null),
+ (null, null, null, null, null, null),
+ (Short.MaxValue, Float.PositiveInfinity, decimal2(Short.MaxValue), true, ts3, ts3),
+ (Short.MinValue, Float.NegativeInfinity, decimal2(Short.MinValue), false, ts4, ts4),
+ (0.toShort, Float.MaxValue, decimal2(0), null, null, null),
+ (0.toShort, Float.MinValue, decimal2(0.01), null, null, null)
).foreach(v => {
- val row = create_row(v._1, v._2, v._3, v._4)
+ val row = create_row(v._1, v._2, v._3, v._4, v._5, v._6)
checkEvaluation(e1, e2.eval(row), row)
})
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
index 1d7af84ef60..468915aa493 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnwrapCastInComparisonEndToEndSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.time.LocalDateTime
+
import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.Decimal
@@ -240,5 +242,27 @@ class UnwrapCastInComparisonEndToEndSuite extends QueryTest with SharedSparkSess
}
}
+ test("SPARK-42597: Support unwrap date type to timestamp type") {
+ val ts1 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 99999000)
+ val ts2 = LocalDateTime.of(2023, 1, 1, 23, 59, 59, 999998000)
+ val ts3 = LocalDateTime.of(2023, 1, 2, 23, 59, 59, 8000)
+
+ withTable(t) {
+ Seq(ts1, ts2, ts3).toDF("ts").write.saveAsTable(t)
+ val df = spark.table(t)
+
+ checkAnswer(
+ df.where("cast(ts as date) > date'2023-01-01'"), Seq(ts3).map(Row(_)))
+ checkAnswer(
+ df.where("cast(ts as date) >= date'2023-01-01'"), Seq(ts1, ts2,
ts3).map(Row(_)))
+ checkAnswer(
+ df.where("cast(ts as date) < date'2023-01-02'"), Seq(ts1,
ts2).map(Row(_)))
+ checkAnswer(
+ df.where("cast(ts as date) <= date'2023-01-02'"), Seq(ts1, ts2,
ts3).map(Row(_)))
+ checkAnswer(
+ df.where("cast(ts as date) = date'2023-01-01'"), Seq(ts1,
ts2).map(Row(_)))
+ }
+ }
+
private def decimal(v: BigDecimal): Decimal = Decimal(v, 5, 2)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]