[SPARK-18936][SQL] Infrastructure for session local timezone support.
## What changes were proposed in this pull request?
As of Spark 2.1, Spark SQL assumes the machine's local timezone for datetime
manipulation, which is problematic when users are not in the same timezone as
the machines, or when different users have different timezones.
We should introduce a session local timezone setting that is used for execution.
An explicit non-goal is locale handling.
### Semantics
Setting the session local timezone means that the timezone-aware expressions
listed below should use the timezone to evaluate values, and also it should be
used to convert (cast) between string and timestamp or between timestamp and
date.
- `CurrentDate`
- `CurrentBatchTimestamp`
- `Hour`
- `Minute`
- `Second`
- `DateFormatClass`
- `ToUnixTimestamp`
- `UnixTimestamp`
- `FromUnixTime`
and the following are implicitly timezone-aware through a cast from timestamp to date:
- `DayOfYear`
- `Year`
- `Quarter`
- `Month`
- `DayOfMonth`
- `WeekOfYear`
- `LastDay`
- `NextDay`
- `TruncDate`
For example, if you have timestamp `"2016-01-01 00:00:00"` in `GMT`, the values
evaluated by some of the timezone-aware expressions are:
```scala
scala> val df = Seq(new java.sql.Timestamp(1451606400000L)).toDF("ts")
df: org.apache.spark.sql.DataFrame = [ts: timestamp]
scala> df.selectExpr("cast(ts as string)", "year(ts)", "month(ts)", "dayofmonth(ts)", "hour(ts)", "minute(ts)", "second(ts)").show(truncate = false)
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|ts                 |year(CAST(ts AS DATE))|month(CAST(ts AS DATE))|dayofmonth(CAST(ts AS DATE))|hour(ts)|minute(ts)|second(ts)|
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|2016-01-01 00:00:00|2016                  |1                      |1                           |0       |0         |0         |
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
```
whereas after setting the session local timezone to `"PST"`, they become:
```scala
scala> spark.conf.set("spark.sql.session.timeZone", "PST")
scala> df.selectExpr("cast(ts as string)", "year(ts)", "month(ts)", "dayofmonth(ts)", "hour(ts)", "minute(ts)", "second(ts)").show(truncate = false)
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|ts                 |year(CAST(ts AS DATE))|month(CAST(ts AS DATE))|dayofmonth(CAST(ts AS DATE))|hour(ts)|minute(ts)|second(ts)|
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
|2015-12-31 16:00:00|2015                  |12                     |31                          |16      |0         |0         |
+-------------------+----------------------+-----------------------+----------------------------+--------+----------+----------+
```
Note that the session local timezone affects only `DataFrame` operations; it
does not apply to `Dataset` operations, `RDD` operations, or `ScalaUDF`s. In
those cases you need to handle timezones yourself.
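The semantics above hinge on one epoch instant mapping to different calendar dates in different time zones. A minimal, self-contained sketch (plain JVM Scala, not Spark's implementation; `localDate` is a hypothetical helper) of the conversion a timezone-aware `millisToDays` has to perform:

```scala
import java.util.{Calendar, TimeZone}

// Map epoch milliseconds to a (year, month, day) triple in the given time zone.
def localDate(millis: Long, tz: TimeZone): (Int, Int, Int) = {
  val cal = Calendar.getInstance(tz)
  cal.setTimeInMillis(millis)
  // Calendar months are 0-based, so add 1.
  (cal.get(Calendar.YEAR), cal.get(Calendar.MONTH) + 1, cal.get(Calendar.DAY_OF_MONTH))
}

val ts = 1451606400000L  // 2016-01-01 00:00:00 GMT, as in the example above
val inGmt = localDate(ts, TimeZone.getTimeZone("GMT"))  // (2016, 1, 1)
val inPst = localDate(ts, TimeZone.getTimeZone("PST"))  // (2015, 12, 31)
```

The same long value yields December 31 in `PST`, which is exactly why `year(ts)` and friends must know the session timezone.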
### Design of the fix
I introduced an analyzer rule to pass the session local timezone to
timezone-aware expressions, and modified `DateTimeUtils` to take a timezone argument.
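The resolution step itself is small. Here is a self-contained toy model of it (illustrative only; `TimeZoneAware`, `HourExpr`, and `resolveTimeZone` are stand-ins, not the actual Spark classes, which operate on a `LogicalPlan`):

```scala
// Toy model of the ResolveTimeZone analyzer rule: expressions that carry no
// explicit time zone get the session-local one; explicit zones are kept.
trait TimeZoneAware {
  def timeZoneId: Option[String]
  def withTimeZone(tz: String): TimeZoneAware
}

case class HourExpr(timeZoneId: Option[String] = None) extends TimeZoneAware {
  def withTimeZone(tz: String): TimeZoneAware = copy(timeZoneId = Option(tz))
}

def resolveTimeZone(exprs: Seq[TimeZoneAware], sessionTz: String): Seq[TimeZoneAware] =
  exprs.map {
    case e if e.timeZoneId.isEmpty => e.withTimeZone(sessionTz)
    case e => e
  }

// An unresolved expression picks up the session zone; an explicit zone survives.
val resolved = resolveTimeZone(Seq(HourExpr(), HourExpr(Some("GMT"))), "PST")
```

Keeping the timezone as an `Option[String]` on each expression, filled in late by the analyzer, lets expressions be constructed without knowing the session config.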
## How was this patch tested?
Existing tests, plus new tests for timezone-aware expressions.
Author: Takuya UESHIN <[email protected]>
Closes #16308 from ueshin/issues/SPARK-18350.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2969fb43
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2969fb43
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2969fb43
Branch: refs/heads/master
Commit: 2969fb4370120a39dae98be716b24dcc0ada2cef
Parents: 7045b8b
Author: Takuya UESHIN <[email protected]>
Authored: Thu Jan 26 11:51:05 2017 +0100
Committer: Herman van Hovell <[email protected]>
Committed: Thu Jan 26 11:51:05 2017 +0100
----------------------------------------------------------------------
.../spark/sql/catalyst/CatalystConf.scala | 7 +-
.../spark/sql/catalyst/analysis/Analyzer.scala | 16 +-
.../spark/sql/catalyst/catalog/interface.scala | 6 +-
.../spark/sql/catalyst/expressions/Cast.scala | 35 +-
.../expressions/datetimeExpressions.scala | 253 ++++++---
.../sql/catalyst/optimizer/Optimizer.scala | 10 +-
.../sql/catalyst/optimizer/expressions.scala | 12 +-
.../sql/catalyst/optimizer/finishAnalysis.scala | 18 +-
.../spark/sql/catalyst/util/DateTimeUtils.scala | 143 +++--
.../sql/catalyst/analysis/AnalysisSuite.scala | 5 +-
.../ResolveGroupingAnalyticsSuite.scala | 7 +-
.../sql/catalyst/expressions/CastSuite.scala | 255 ++++-----
.../expressions/DateExpressionsSuite.scala | 525 ++++++++++++-------
.../BinaryComparisonSimplificationSuite.scala | 3 +-
.../optimizer/BooleanSimplificationSuite.scala | 2 +-
.../optimizer/CombiningLimitsSuite.scala | 3 +-
.../optimizer/DecimalAggregatesSuite.scala | 3 +-
.../catalyst/optimizer/OptimizeInSuite.scala | 2 +-
.../plans/ConstraintPropagationSuite.scala | 51 +-
.../sql/catalyst/util/DateTimeUtilsSuite.scala | 402 +++++++-------
.../scala/org/apache/spark/sql/Column.scala | 2 +-
.../scala/org/apache/spark/sql/Dataset.scala | 10 +-
.../apache/spark/sql/catalyst/SQLBuilder.scala | 2 +-
.../execution/OptimizeMetadataOnlyQuery.scala | 5 +-
.../spark/sql/execution/QueryExecution.scala | 26 +-
.../datasources/FileFormatWriter.scala | 4 +-
.../PartitioningAwareFileIndex.scala | 5 +-
.../streaming/IncrementalExecution.scala | 2 +-
.../execution/streaming/StreamExecution.scala | 2 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 10 +-
.../org/apache/spark/sql/DataFrameSuite.scala | 25 +
.../hive/execution/HiveCompatibilitySuite.scala | 5 +
32 files changed, 1182 insertions(+), 674 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index b805cfe..0b6fa56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
+import java.util.TimeZone
+
import org.apache.spark.sql.catalyst.analysis._
/**
@@ -36,6 +38,8 @@ trait CatalystConf {
def warehousePath: String
+ def sessionLocalTimeZone: String
+
/** If true, cartesian products between relations will be allowed for all
* join types(inner, (left|right|full) outer).
* If false, cartesian products will require explicit CROSS JOIN syntax.
@@ -68,5 +72,6 @@ case class SimpleCatalystConf(
runSQLonFile: Boolean = true,
crossJoinEnabled: Boolean = false,
cboEnabled: Boolean = false,
- warehousePath: String = "/user/hive/warehouse")
+ warehousePath: String = "/user/hive/warehouse",
+ sessionLocalTimeZone: String = TimeZone.getDefault().getID)
extends CatalystConf
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cb56e94..8ec3304 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -155,6 +155,8 @@ class Analyzer(
HandleNullInputsForUDF),
Batch("FixNullability", Once,
FixNullability),
+ Batch("ResolveTimeZone", Once,
+ ResolveTimeZone),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
@@ -223,7 +225,7 @@ class Analyzer(
case ne: NamedExpression => ne
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
- case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)()
+ case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)()
case e: ExtractValue => Alias(e, toPrettySQL(e))()
case e if optGenAliasFunc.isDefined =>
Alias(child, optGenAliasFunc.get.apply(e))()
@@ -2312,6 +2314,18 @@ class Analyzer(
}
}
}
+
+  /**
+   * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local
+   * time zone.
+   */
+  object ResolveTimeZone extends Rule[LogicalPlan] {
+
+    override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
+      case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+        e.withTimeZone(conf.sessionLocalTimeZone)
+    }
+  }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index b8dc5f9..a8fa78d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -26,10 +26,10 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.quoteIdentifier
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.StructType
-
/**
* A function defined in the catalog.
*
@@ -114,7 +114,9 @@ case class CatalogTablePartition(
*/
def toRow(partitionSchema: StructType): InternalRow = {
InternalRow.fromSeq(partitionSchema.map { field =>
- Cast(Literal(spec(field.name)), field.dataType).eval()
+ // TODO: use correct timezone for partition values.
+ Cast(Literal(spec(field.name)), field.dataType,
+ Option(DateTimeUtils.defaultTimeZone().getID)).eval()
})
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index ad59271..a36d350 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -131,7 +131,12 @@ object Cast {
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
}
-/** Cast the child expression to the target data type. */
+/**
+ * Cast the child expression to the target data type.
+ *
+ * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
+ * session local timezone by an analyzer [[ResolveTimeZone]].
+ */
@ExpressionDescription(
  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
extended = """
@@ -139,7 +144,10 @@ object Cast {
> SELECT _FUNC_('10' as int);
10
""")
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
+case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None)
+ extends UnaryExpression with TimeZoneAwareExpression with NullIntolerant {
+
+ def this(child: Expression, dataType: DataType) = this(child, dataType, None)
override def toString: String = s"cast($child as ${dataType.simpleString})"
@@ -154,6 +162,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
  // [[func]] assumes the input is no longer null because eval already does the null check.
  @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
@@ -162,7 +173,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
     case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
case TimestampType => buildCast[Long](_,
- t => UTF8String.fromString(DateTimeUtils.timestampToString(t)))
+ t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
}
@@ -207,7 +218,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// TimestampConverter
private[this] def castToTimestamp(from: DataType): Any => Any = from match {
case StringType =>
-      buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs).orNull)
+      buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs, timeZone).orNull)
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1L else 0)
case LongType =>
@@ -219,7 +230,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case ByteType =>
buildCast[Byte](_, b => longToTimestamp(b.toLong))
case DateType =>
- buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 1000)
+ buildCast[Int](_, d => DateTimeUtils.daysToMillis(d, timeZone) * 1000)
// TimestampWritable.decimalToTimestamp
case DecimalType() =>
buildCast[Decimal](_, d => decimalToTimestamp(d))
@@ -254,7 +265,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case TimestampType =>
// throw valid precision more than seconds, according to Hive.
// Timestamp.nanos is in 0 to 999,999,999, no more than a second.
- buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L))
+ buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L, timeZone))
}
// IntervalConverter
@@ -531,8 +542,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
case TimestampType =>
+ val tz = ctx.addReferenceMinorObj(timeZone)
       (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
-        org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));"""
+        org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
case _ =>
      (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
}
@@ -558,8 +570,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
"""
case TimestampType =>
+ val tz = ctx.addReferenceMinorObj(timeZone)
(c, evPrim, evNull) =>
-        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);";
+        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);"
case _ =>
(c, evPrim, evNull) => s"$evNull = true;"
}
@@ -637,11 +650,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
from: DataType,
ctx: CodegenContext): CastFunction = from match {
case StringType =>
+ val tz = ctx.addReferenceMinorObj(timeZone)
val longOpt = ctx.freshName("longOpt")
(c, evPrim, evNull) =>
s"""
          scala.Option<Long> $longOpt =
-            org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c);
+            org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz);
if ($longOpt.isDefined()) {
$evPrim = ((Long) $longOpt.get()).longValue();
} else {
@@ -653,8 +667,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _: IntegralType =>
(c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};"
case DateType =>
+ val tz = ctx.addReferenceMinorObj(timeZone)
(c, evPrim, evNull) =>
-        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;"
+        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;"
case DecimalType() =>
(c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};"
case DoubleType =>
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index ef1ac36..bad8a71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
-import java.text.SimpleDateFormat
-import java.util.{Calendar, Locale, TimeZone}
+import java.text.DateFormat
+import java.util.{Calendar, TimeZone}
-import scala.util.Try
+import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
@@ -30,6 +30,20 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
/**
+ * Common base class for time zone aware expressions.
+ */
+trait TimeZoneAwareExpression extends Expression {
+
+ /** the timezone ID to be used to evaluate value. */
+ def timeZoneId: Option[String]
+
+ /** Returns a copy of this expression with the specified timeZoneId. */
+ def withTimeZone(timeZoneId: String): TimeZoneAwareExpression
+
+ @transient lazy val timeZone: TimeZone = TimeZone.getTimeZone(timeZoneId.get)
+}
+
+/**
* Returns the current date at the start of query evaluation.
* All calls of current_date within the same query return the same value.
*
@@ -37,14 +51,21 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
*/
@ExpressionDescription(
  usage = "_FUNC_() - Returns the current date at the start of query evaluation.")
-case class CurrentDate() extends LeafExpression with CodegenFallback {
+case class CurrentDate(timeZoneId: Option[String] = None)
+ extends LeafExpression with TimeZoneAwareExpression with CodegenFallback {
+
+ def this() = this(None)
+
override def foldable: Boolean = true
override def nullable: Boolean = false
override def dataType: DataType = DateType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override def eval(input: InternalRow): Any = {
- DateTimeUtils.millisToDays(System.currentTimeMillis())
+ DateTimeUtils.millisToDays(System.currentTimeMillis(), timeZone)
}
override def prettyName: String = "current_date"
@@ -78,11 +99,19 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
 *
 * There is no code generation since this expression should be replaced with a literal.
*/
-case class CurrentBatchTimestamp(timestampMs: Long, dataType: DataType)
- extends LeafExpression with Nondeterministic with CodegenFallback {
+case class CurrentBatchTimestamp(
+ timestampMs: Long,
+ dataType: DataType,
+ timeZoneId: Option[String] = None)
+  extends LeafExpression with TimeZoneAwareExpression with Nondeterministic with CodegenFallback {
+
+  def this(timestampMs: Long, dataType: DataType) = this(timestampMs, dataType, None)
override def nullable: Boolean = false
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override def prettyName: String = "current_batch_timestamp"
override protected def initializeInternal(partitionIndex: Int): Unit = {}
@@ -96,7 +125,7 @@ case class CurrentBatchTimestamp(timestampMs: Long, dataType: DataType)
def toLiteral: Literal = dataType match {
case _: TimestampType =>
       Literal(DateTimeUtils.fromJavaTimestamp(new Timestamp(timestampMs)), TimestampType)
-    case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs), DateType)
+    case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs, timeZone), DateType)
}
}
@@ -172,19 +201,26 @@ case class DateSub(startDate: Expression, days: Expression)
> SELECT _FUNC_('2009-07-30 12:58:59');
12
""")
-case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Hour(child: Expression, timeZoneId: Option[String] = None)
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+  def this(child: Expression) = this(child, None)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def dataType: DataType = IntegerType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override protected def nullSafeEval(timestamp: Any): Any = {
- DateTimeUtils.getHours(timestamp.asInstanceOf[Long])
+ DateTimeUtils.getHours(timestamp.asInstanceOf[Long], timeZone)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)")
+ defineCodeGen(ctx, ev, c => s"$dtu.getHours($c, $tz)")
}
}
@@ -195,19 +231,26 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu
> SELECT _FUNC_('2009-07-30 12:58:59');
58
""")
-case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Minute(child: Expression, timeZoneId: Option[String] = None)
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+  def this(child: Expression) = this(child, None)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def dataType: DataType = IntegerType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override protected def nullSafeEval(timestamp: Any): Any = {
- DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long])
+ DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long], timeZone)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)")
+ defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c, $tz)")
}
}
@@ -218,19 +261,26 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn
> SELECT _FUNC_('2009-07-30 12:58:59');
59
""")
-case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+case class Second(child: Expression, timeZoneId: Option[String] = None)
+  extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+  def this(child: Expression) = this(child, None)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
override def dataType: DataType = IntegerType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override protected def nullSafeEval(timestamp: Any): Any = {
- DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long])
+ DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long], timeZone)
}
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
- defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)")
+ defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c, $tz)")
}
}
@@ -401,22 +451,28 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
2016
""")
// scalastyle:on line.size.limit
-case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression
-  with ImplicitCastInputTypes {
+case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None)
+  extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+  def this(left: Expression, right: Expression) = this(left, right, None)
override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType)
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override protected def nullSafeEval(timestamp: Any, format: Any): Any = {
-    val sdf = new SimpleDateFormat(format.toString, Locale.US)
-    UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000)))
+    val df = DateTimeUtils.newDateFormat(format.toString, timeZone)
+    UTF8String.fromString(df.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000)))
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val sdf = classOf[SimpleDateFormat].getName
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+ val tz = ctx.addReferenceMinorObj(timeZone)
defineCodeGen(ctx, ev, (timestamp, format) => {
- s"""UTF8String.fromString((new $sdf($format.toString()))
+ s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz)
.format(new java.util.Date($timestamp / 1000)))"""
})
}
@@ -435,10 +491,20 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
> SELECT _FUNC_('2016-04-08', 'yyyy-MM-dd');
1460041200
""")
-case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime {
+case class ToUnixTimestamp(
+ timeExp: Expression,
+ format: Expression,
+ timeZoneId: Option[String] = None)
+ extends UnixTime {
+
+  def this(timeExp: Expression, format: Expression) = this(timeExp, format, None)
+
override def left: Expression = timeExp
override def right: Expression = format
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
def this(time: Expression) = {
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
}
@@ -465,10 +531,17 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix
> SELECT _FUNC_('2016-04-08', 'yyyy-MM-dd');
1460041200
""")
-case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime {
+case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Option[String] = None)
+  extends UnixTime {
+
+  def this(timeExp: Expression, format: Expression) = this(timeExp, format, None)
+
override def left: Expression = timeExp
override def right: Expression = format
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
def this(time: Expression) = {
this(time, Literal("yyyy-MM-dd HH:mm:ss"))
}
@@ -480,7 +553,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi
override def prettyName: String = "unix_timestamp"
}
-abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
+abstract class UnixTime
+  extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringType, DateType, TimestampType), StringType)
@@ -489,8 +563,12 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
override def nullable: Boolean = true
  private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
- private lazy val formatter: SimpleDateFormat =
- Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null)
+ private lazy val formatter: DateFormat =
+ try {
+ DateTimeUtils.newDateFormat(constFormat.toString, timeZone)
+ } catch {
+ case NonFatal(_) => null
+ }
override def eval(input: InternalRow): Any = {
val t = left.eval(input)
@@ -499,15 +577,19 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
} else {
left.dataType match {
case DateType =>
- DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L
+ DateTimeUtils.daysToMillis(t.asInstanceOf[Int], timeZone) / 1000L
case TimestampType =>
t.asInstanceOf[Long] / 1000000L
case StringType if right.foldable =>
if (constFormat == null || formatter == null) {
null
} else {
-            Try(formatter.parse(
-              t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
+ try {
+ formatter.parse(
+ t.asInstanceOf[UTF8String].toString).getTime / 1000L
+ } catch {
+ case NonFatal(_) => null
+ }
}
case StringType =>
val f = right.eval(input)
@@ -515,8 +597,12 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
null
} else {
val formatString = f.asInstanceOf[UTF8String].toString
-          Try(new SimpleDateFormat(formatString, Locale.US).parse(
-            t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
+ try {
+ DateTimeUtils.newDateFormat(formatString, timeZone).parse(
+ t.asInstanceOf[UTF8String].toString).getTime / 1000L
+ } catch {
+ case NonFatal(_) => null
+ }
}
}
}
@@ -525,11 +611,11 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
left.dataType match {
case StringType if right.foldable =>
- val sdf = classOf[SimpleDateFormat].getName
+ val df = classOf[DateFormat].getName
if (formatter == null) {
ExprCode("", "true", ctx.defaultValue(dataType))
} else {
- val formatterName = ctx.addReferenceObj("formatter", formatter, sdf)
+ val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
${eval1.code}
@@ -544,12 +630,13 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
}""")
}
case StringType =>
- val sdf = classOf[SimpleDateFormat].getName
+ val tz = ctx.addReferenceMinorObj(timeZone)
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (string, format) => {
s"""
try {
- ${ev.value} =
-                (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L;
+ ${ev.value} = $dtu.newDateFormat($format.toString(), $tz)
+ .parse($string.toString()).getTime() / 1000L;
} catch (java.lang.IllegalArgumentException e) {
${ev.isNull} = true;
} catch (java.text.ParseException e) {
@@ -567,6 +654,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
${ev.value} = ${eval1.value} / 1000000L;
}""")
case DateType =>
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
ev.copy(code = s"""
@@ -574,7 +662,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
boolean ${ev.isNull} = ${eval1.isNull};
        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- ${ev.value} = $dtu.daysToMillis(${eval1.value}) / 1000L;
+ ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L;
}""")
}
}
@@ -593,8 +681,10 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');
1970-01-01 00:00:00
""")
-case class FromUnixTime(sec: Expression, format: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[String] = None)
+  extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+ def this(sec: Expression, format: Expression) = this(sec, format, None)
override def left: Expression = sec
override def right: Expression = format
@@ -610,9 +700,16 @@ case class FromUnixTime(sec: Expression, format: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
  private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
- private lazy val formatter: SimpleDateFormat =
- Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null)
+ private lazy val formatter: DateFormat =
+ try {
+ DateTimeUtils.newDateFormat(constFormat.toString, timeZone)
+ } catch {
+ case NonFatal(_) => null
+ }
override def eval(input: InternalRow): Any = {
val time = left.eval(input)
@@ -623,30 +720,36 @@ case class FromUnixTime(sec: Expression, format: Expression)
if (constFormat == null || formatter == null) {
null
} else {
- Try(UTF8String.fromString(formatter.format(
- new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
+ try {
+ UTF8String.fromString(formatter.format(
+ new java.util.Date(time.asInstanceOf[Long] * 1000L)))
+ } catch {
+ case NonFatal(_) => null
+ }
}
} else {
val f = format.eval(input)
if (f == null) {
null
} else {
- Try(
- UTF8String.fromString(new SimpleDateFormat(f.toString, Locale.US).
- format(new java.util.Date(time.asInstanceOf[Long] * 1000L)))
- ).getOrElse(null)
+ try {
+ UTF8String.fromString(DateTimeUtils.newDateFormat(f.toString, timeZone)
+ .format(new java.util.Date(time.asInstanceOf[Long] * 1000L)))
+ } catch {
+ case NonFatal(_) => null
+ }
}
}
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val sdf = classOf[SimpleDateFormat].getName
+ val df = classOf[DateFormat].getName
if (format.foldable) {
if (formatter == null) {
ExprCode("", "true", "(UTF8String) null")
} else {
- val formatterName = ctx.addReferenceObj("formatter", formatter, sdf)
+ val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val t = left.genCode(ctx)
ev.copy(code = s"""
${t.code}
@@ -662,14 +765,16 @@ case class FromUnixTime(sec: Expression, format: Expression)
}""")
}
} else {
+ val tz = ctx.addReferenceMinorObj(timeZone)
+ val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
nullSafeCodeGen(ctx, ev, (seconds, f) => {
s"""
try {
- ${ev.value} = UTF8String.fromString((new $sdf($f.toString())).format(
+ ${ev.value} = UTF8String.fromString($dtu.newDateFormat($f.toString(), $tz).format(
new java.util.Date($seconds * 1000L)));
} catch (java.lang.IllegalArgumentException e) {
${ev.isNull} = true;
- }""".stripMargin
+ }"""
})
}
}
@@ -776,8 +881,10 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
/**
* Adds an interval to timestamp.
*/
-case class TimeAdd(start: Expression, interval: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
+ extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+ def this(start: Expression, interval: Expression) = this(start, interval, None)
override def left: Expression = start
override def right: Expression = interval
@@ -788,16 +895,20 @@ case class TimeAdd(start: Expression, interval: Expression)
override def dataType: DataType = TimestampType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
- start.asInstanceOf[Long], itvl.months, itvl.microseconds)
+ start.asInstanceOf[Long], itvl.months, itvl.microseconds, timeZone)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
- s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
+ s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)"""
})
}
}
@@ -863,8 +974,10 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
/**
* Subtracts an interval from timestamp.
*/
-case class TimeSub(start: Expression, interval: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[String] = None)
+ extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+ def this(start: Expression, interval: Expression) = this(start, interval, None)
override def left: Expression = start
override def right: Expression = interval
@@ -875,16 +988,20 @@ case class TimeSub(start: Expression, interval: Expression)
override def dataType: DataType = TimestampType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[CalendarInterval]
DateTimeUtils.timestampAddInterval(
- start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
+ start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, timeZone)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, i) => {
- s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
+ s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)"""
})
}
}
@@ -937,8 +1054,10 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
3.94959677
""")
// scalastyle:on line.size.limit
-case class MonthsBetween(date1: Expression, date2: Expression)
- extends BinaryExpression with ImplicitCastInputTypes {
+case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None)
+ extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes {
+
+ def this(date1: Expression, date2: Expression) = this(date1, date2, None)
override def left: Expression = date1
override def right: Expression = date2
@@ -947,14 +1066,18 @@ case class MonthsBetween(date1: Expression, date2: Expression)
override def dataType: DataType = DoubleType
+ override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+ copy(timeZoneId = Option(timeZoneId))
+
override def nullSafeEval(t1: Any, t2: Any): Any = {
- DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long])
+ DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val tz = ctx.addReferenceMinorObj(timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (l, r) => {
- s"""$dtu.monthsBetween($l, $r)"""
+ s"""$dtu.monthsBetween($l, $r, $tz)"""
})
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 20b3898..55d37cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -94,7 +94,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
CombineLimits,
CombineUnions,
// Constant folding and strength reduction
- NullPropagation,
+ NullPropagation(conf),
FoldablePropagation,
OptimizeIn(conf),
ConstantFolding,
@@ -114,7 +114,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::
Batch("Decimal Optimizations", fixedPoint,
- DecimalAggregates) ::
+ DecimalAggregates(conf)) ::
Batch("Typed Filter Optimization", fixedPoint,
CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
@@ -1026,7 +1026,7 @@ case class CheckCartesianProducts(conf: CatalystConf)
* This uses the same rules for increasing the precision and scale of the output as
* [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]].
*/
-object DecimalAggregates extends Rule[LogicalPlan] {
+case class DecimalAggregates(conf: CatalystConf) extends Rule[LogicalPlan] {
import Decimal.MAX_LONG_DIGITS
/** Maximum number of decimal digits representable precisely in a Double */
@@ -1044,7 +1044,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
- DecimalType(prec + 4, scale + 4))
+ DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
case _ => we
}
@@ -1056,7 +1056,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
- DecimalType(prec + 4, scale + 4))
+ DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone))
case _ => ae
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 949ccdc..5bfc0ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -340,7 +340,7 @@ object LikeSimplification extends Rule[LogicalPlan] {
* equivalent [[Literal]] values. This rule is more specific with
* Null value propagation from bottom to top of the expression tree.
*/
-object NullPropagation extends Rule[LogicalPlan] {
+case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
@@ -348,10 +348,10 @@ object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ WindowExpression(Cast(Literal(0L, _), _), _) =>
- Cast(Literal(0L), e.dataType)
+ case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
+ Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
- Cast(Literal(0L), e.dataType)
+ Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
@@ -518,8 +518,8 @@ case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
*/
object SimplifyCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Cast(e, dataType) if e.dataType == dataType => e
- case c @ Cast(e, dataType) => (e.dataType, dataType) match {
+ case Cast(e, dataType, _) if e.dataType == dataType => e
+ case c @ Cast(e, dataType, _) => (e.dataType, dataType) match {
case (ArrayType(from, false), ArrayType(to, true)) if from == to => e
case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true))
if fromKey == toKey && fromValue == toValue => e
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index f20eb95..89e1dc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -17,10 +17,15 @@
package org.apache.spark.sql.catalyst.optimizer
+import java.util.TimeZone
+
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -41,13 +46,18 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
*/
object ComputeCurrentTime extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
- val dateExpr = CurrentDate()
+ val currentDates = mutable.Map.empty[String, Literal]
val timeExpr = CurrentTimestamp()
- val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
- val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
+ val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
+ val currentTime = Literal.create(timestamp, timeExpr.dataType)
plan transformAllExpressions {
- case CurrentDate() => currentDate
+ case CurrentDate(Some(timeZoneId)) =>
+ currentDates.getOrElseUpdate(timeZoneId, {
+ Literal.create(
+ DateTimeUtils.millisToDays(timestamp / 1000L, TimeZone.getTimeZone(timeZoneId)),
+ DateType)
+ })
case CurrentTimestamp() => currentTime
}
}
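For reference (not part of the patch), the `getOrElseUpdate` pattern above evaluates the current date once per distinct timezone id in the plan, all derived from a single fixed query timestamp. A minimal Java sketch of the same idea, with hypothetical names (`currentDates`, `PerZoneCurrentDateDemo`) chosen only for illustration:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.TimeZone;

public class PerZoneCurrentDateDemo {
    static final long MILLIS_PER_DAY = 24L * 60 * 60 * 1000;

    // Same idea as currentDates.getOrElseUpdate: compute the "current date"
    // literal once per distinct timezone id, from one fixed query timestamp.
    static Map<String, Integer> currentDates(long nowMillis, String[] zoneIds) {
        Map<String, Integer> cache = new HashMap<>();
        for (String id : zoneIds) {
            cache.computeIfAbsent(id, zid -> {
                long local = nowMillis + TimeZone.getTimeZone(zid).getOffset(nowMillis);
                return (int) Math.floor((double) local / MILLIS_PER_DAY);
            });
        }
        return cache;
    }

    public static void main(String[] args) {
        Map<String, Integer> dates = currentDates(0L, new String[] {"GMT", "GMT-08:00", "GMT"});
        System.out.println(dates.size()); // 2 distinct zones -> 2 cached literals
    }
}
```

Caching per zone id keeps `CurrentDate` deterministic within a query even when different expressions carry different session timezones.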
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index a96a3b7..af70efb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -60,7 +60,7 @@ object DateTimeUtils {
final val TimeZoneGMT = TimeZone.getTimeZone("GMT")
final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12)
- @transient lazy val defaultTimeZone = TimeZone.getDefault
+ def defaultTimeZone(): TimeZone = TimeZone.getDefault()
// Reuse the Calendar object in each thread as it is expensive to create in each method call.
private val threadLocalGmtCalendar = new ThreadLocal[Calendar] {
@@ -69,20 +69,19 @@ object DateTimeUtils {
}
}
- // Java TimeZone has no mention of thread safety. Use thread local instance to be safe.
- private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] {
- override protected def initialValue: TimeZone = {
- Calendar.getInstance.getTimeZone
- }
- }
-
// `SimpleDateFormat` is not thread-safe.
- val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
+ private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] {
override def initialValue(): SimpleDateFormat = {
new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
}
}
+ def getThreadLocalTimestampFormat(timeZone: TimeZone): DateFormat = {
+ val sdf = threadLocalTimestampFormat.get()
+ sdf.setTimeZone(timeZone)
+ sdf
+ }
+
// `SimpleDateFormat` is not thread-safe.
private val threadLocalDateFormat = new ThreadLocal[DateFormat] {
override def initialValue(): SimpleDateFormat = {
@@ -90,28 +89,54 @@ object DateTimeUtils {
}
}
+ def getThreadLocalDateFormat(): DateFormat = {
+ val sdf = threadLocalDateFormat.get()
+ sdf.setTimeZone(defaultTimeZone())
+ sdf
+ }
+
+ def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = {
+ val sdf = new SimpleDateFormat(formatString, Locale.US)
+ sdf.setTimeZone(timeZone)
+ sdf
+ }
+
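For reference (not part of the patch), the new `newDateFormat` helper amounts to this small Java sketch; the class name `NewDateFormatDemo` is illustrative only:

```java
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Locale;
import java.util.TimeZone;

public class NewDateFormatDemo {
    // Same shape as DateTimeUtils.newDateFormat: a fresh formatter bound to
    // the given timezone, so callers never share mutable SimpleDateFormat state.
    static SimpleDateFormat newDateFormat(String pattern, TimeZone tz) {
        SimpleDateFormat sdf = new SimpleDateFormat(pattern, Locale.US);
        sdf.setTimeZone(tz);
        return sdf;
    }

    public static void main(String[] args) {
        Date epoch = new Date(0L); // 1970-01-01 00:00:00 UTC
        // The same instant renders differently per zone:
        System.out.println(newDateFormat("yyyy-MM-dd HH:mm:ss", TimeZone.getTimeZone("GMT")).format(epoch));       // 1970-01-01 00:00:00
        System.out.println(newDateFormat("yyyy-MM-dd HH:mm:ss", TimeZone.getTimeZone("GMT-08:00")).format(epoch)); // 1969-12-31 16:00:00
    }
}
```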
// we should use the exact day as Int, for example, (year, month, day) -> day
def millisToDays(millisUtc: Long): SQLDate = {
+ millisToDays(millisUtc, defaultTimeZone())
+ }
+
+ def millisToDays(millisUtc: Long, timeZone: TimeZone): SQLDate = {
// SPARK-6785: use Math.floor so negative number of days (dates before 1970)
// will correctly work as input for function toJavaDate(Int)
- val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc)
+ val millisLocal = millisUtc + timeZone.getOffset(millisUtc)
Math.floor(millisLocal.toDouble / MILLIS_PER_DAY).toInt
}
// reverse of millisToDays
def daysToMillis(days: SQLDate): Long = {
+ daysToMillis(days, defaultTimeZone())
+ }
+
+ def daysToMillis(days: SQLDate, timeZone: TimeZone): Long = {
val millisLocal = days.toLong * MILLIS_PER_DAY
- millisLocal - getOffsetFromLocalMillis(millisLocal, threadLocalLocalTimeZone.get())
+ millisLocal - getOffsetFromLocalMillis(millisLocal, timeZone)
}
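The new `millisToDays` overload can be sketched in Java as follows (illustration only, not part of the patch); it shows why the timezone parameter can move the day boundary:

```java
import java.util.TimeZone;

public class MillisToDaysDemo {
    static final long MILLIS_PER_DAY = 24L * 60 * 60 * 1000;

    // Mirrors the new millisToDays overload: shift the UTC instant into the
    // zone's local clock, then floor-divide so pre-1970 dates go negative.
    static int millisToDays(long millisUtc, TimeZone tz) {
        long millisLocal = millisUtc + tz.getOffset(millisUtc);
        return (int) Math.floor((double) millisLocal / MILLIS_PER_DAY);
    }

    public static void main(String[] args) {
        long epoch = 0L; // 1970-01-01 00:00:00 UTC
        System.out.println(millisToDays(epoch, TimeZone.getTimeZone("GMT")));       // 0
        System.out.println(millisToDays(epoch, TimeZone.getTimeZone("GMT-08:00"))); // -1: still 1969-12-31 locally
    }
}
```

The same instant lands on different date ordinals depending on the session timezone, which is exactly why the `DateType` paths above now thread a `TimeZone` through.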
def dateToString(days: SQLDate): String =
- threadLocalDateFormat.get.format(toJavaDate(days))
+ getThreadLocalDateFormat.format(toJavaDate(days))
// Converts Timestamp to string according to Hive TimestampWritable convention.
def timestampToString(us: SQLTimestamp): String = {
+ timestampToString(us, defaultTimeZone())
+ }
+
+ // Converts Timestamp to string according to Hive TimestampWritable convention.
+ def timestampToString(us: SQLTimestamp, timeZone: TimeZone): String = {
val ts = toJavaTimestamp(us)
val timestampString = ts.toString
- val formatted = threadLocalTimestampFormat.get.format(ts)
+ val timestampFormat = getThreadLocalTimestampFormat(timeZone)
+ val formatted = timestampFormat.format(ts)
if (timestampString.length > 19 && timestampString.substring(19) != ".0") {
formatted + timestampString.substring(19)
@@ -233,10 +258,14 @@ object DateTimeUtils {
* `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m`
*/
def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = {
+ stringToTimestamp(s, defaultTimeZone())
+ }
+
+ def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = {
if (s == null) {
return None
}
- var timeZone: Option[Byte] = None
+ var tz: Option[Byte] = None
val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0)
var i = 0
var currentSegmentValue = 0
@@ -289,12 +318,12 @@ object DateTimeUtils {
segments(i) = currentSegmentValue
currentSegmentValue = 0
i += 1
- timeZone = Some(43)
+ tz = Some(43)
} else if (b == '-' || b == '+') {
segments(i) = currentSegmentValue
currentSegmentValue = 0
i += 1
- timeZone = Some(b)
+ tz = Some(b)
} else if (b == '.' && i == 5) {
segments(i) = currentSegmentValue
currentSegmentValue = 0
@@ -349,11 +378,11 @@ object DateTimeUtils {
return None
}
- val c = if (timeZone.isEmpty) {
- Calendar.getInstance()
+ val c = if (tz.isEmpty) {
+ Calendar.getInstance(timeZone)
} else {
Calendar.getInstance(
- TimeZone.getTimeZone(f"GMT${timeZone.get.toChar}${segments(7)}%02d:${segments(8)}%02d"))
+ TimeZone.getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d"))
}
c.set(Calendar.MILLISECOND, 0)
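For reference (not part of the patch): when the string carries a `+hh:mm`/`-hh:mm` suffix, the parser builds a fixed `GMT±hh:mm` zone, and only strings without a suffix now fall back to the passed-in session timezone. A small Java illustration of the suffix case:

```java
import java.util.Calendar;
import java.util.TimeZone;

public class TzSuffixDemo {
    public static void main(String[] args) {
        // "2015-03-18T12:03:17+07:30": the explicit offset wins over any session zone.
        Calendar c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"));
        c.clear();
        c.set(2015, Calendar.MARCH, 18, 12, 3, 17);

        // Same instant read back in UTC: 12:03:17 - 07:30 = 04:33:17
        Calendar utc = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
        utc.setTimeInMillis(c.getTimeInMillis());
        System.out.printf("%02d:%02d:%02d%n",
            utc.get(Calendar.HOUR_OF_DAY), utc.get(Calendar.MINUTE), utc.get(Calendar.SECOND));
    }
}
```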
@@ -452,7 +481,11 @@ object DateTimeUtils {
}
private def localTimestamp(microsec: SQLTimestamp): SQLTimestamp = {
- absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L
+ localTimestamp(microsec, defaultTimeZone())
+ }
+
+ private def localTimestamp(microsec: SQLTimestamp, timeZone: TimeZone): SQLTimestamp = {
+ absoluteMicroSecond(microsec) + timeZone.getOffset(microsec / 1000) * 1000L
}
/**
@@ -463,6 +496,13 @@ object DateTimeUtils {
}
/**
+ * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds.
+ */
+ def getHours(microsec: SQLTimestamp, timeZone: TimeZone): Int = {
+ ((localTimestamp(microsec, timeZone) / MICROS_PER_SECOND / 3600) % 24).toInt
+ }
+
+ /**
* Returns the minute value of a given timestamp value. The timestamp is expressed in
* microseconds.
*/
@@ -471,6 +511,14 @@ object DateTimeUtils {
}
/**
+ * Returns the minute value of a given timestamp value. The timestamp is expressed in
+ * microseconds.
+ */
+ def getMinutes(microsec: SQLTimestamp, timeZone: TimeZone): Int = {
+ ((localTimestamp(microsec, timeZone) / MICROS_PER_SECOND / 60) % 60).toInt
+ }
+
+ /**
* Returns the second value of a given timestamp value. The timestamp is expressed in
* microseconds.
*/
@@ -478,6 +526,14 @@ object DateTimeUtils {
((localTimestamp(microsec) / MICROS_PER_SECOND) % 60).toInt
}
+ /**
+ * Returns the second value of a given timestamp value. The timestamp is expressed in
+ * microseconds.
+ */
+ def getSeconds(microsec: SQLTimestamp, timeZone: TimeZone): Int = {
+ ((localTimestamp(microsec, timeZone) / MICROS_PER_SECOND) % 60).toInt
+ }
+
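The new field extractors all follow the same pattern: shift into local time, then read the field out of the day. A Java sketch of `getHours` for non-negative timestamps (illustrative; the real `localTimestamp` also normalizes negative values via `absoluteMicroSecond`):

```java
import java.util.TimeZone;

public class TimeFieldDemo {
    static final long MICROS_PER_SECOND = 1_000_000L;

    // Sketch of the new getHours overload: add the zone offset, then take
    // the hour-of-day component of the shifted microsecond count.
    static int getHours(long micros, TimeZone tz) {
        long local = micros + tz.getOffset(micros / 1000L) * 1000L;
        return (int) ((local / MICROS_PER_SECOND / 3600) % 24);
    }

    public static void main(String[] args) {
        long ts = 1451606400L * MICROS_PER_SECOND; // 2016-01-01 00:00:00 GMT (the example in the PR description)
        System.out.println(getHours(ts, TimeZone.getTimeZone("GMT")));       // 0
        System.out.println(getHours(ts, TimeZone.getTimeZone("GMT+09:00"))); // 9
    }
}
```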
private[this] def isLeapYear(year: Int): Boolean = {
(year % 4) == 0 && ((year % 100) != 0 || (year % 400) == 0)
}
@@ -742,9 +798,23 @@ object DateTimeUtils {
* Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
*/
def timestampAddInterval(start: SQLTimestamp, months: Int, microseconds: Long): SQLTimestamp = {
- val days = millisToDays(start / 1000L)
+ timestampAddInterval(start, months, microseconds, defaultTimeZone())
+ }
+
+ /**
+ * Add timestamp and full interval.
+ * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
+ */
+ def timestampAddInterval(
+ start: SQLTimestamp,
+ months: Int,
+ microseconds: Long,
+ timeZone: TimeZone): SQLTimestamp = {
+ val days = millisToDays(start / 1000L, timeZone)
val newDays = dateAddMonths(days, months)
- daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds
+ start +
+ daysToMillis(newDays, timeZone) * 1000L - daysToMillis(days, timeZone) * 1000L +
+ microseconds
}
/**
@@ -758,10 +828,24 @@ object DateTimeUtils {
* 8 digits.
*/
def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = {
+ monthsBetween(time1, time2, defaultTimeZone())
+ }
+
+ /**
+ * Returns number of months between time1 and time2. time1 and time2 are expressed in
+ * microseconds since 1.1.1970.
+ *
+ * If time1 and time2 having the same day of month, or both are the last day of month,
+ * it returns an integer (time under a day will be ignored).
+ *
+ * Otherwise, the difference is calculated based on 31 days per month, and rounding to
+ * 8 digits.
+ */
+ def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = {
val millis1 = time1 / 1000L
val millis2 = time2 / 1000L
- val date1 = millisToDays(millis1)
- val date2 = millisToDays(millis2)
+ val date1 = millisToDays(millis1, timeZone)
+ val date2 = millisToDays(millis2, timeZone)
val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1)
val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2)
@@ -772,8 +856,8 @@ object DateTimeUtils {
return (months1 - months2).toDouble
}
// milliseconds is enough for 8 digits precision on the right side
- val timeInDay1 = millis1 - daysToMillis(date1)
- val timeInDay2 = millis2 - daysToMillis(date2)
+ val timeInDay1 = millis1 - daysToMillis(date1, timeZone)
+ val timeInDay2 = millis2 - daysToMillis(date2, timeZone)
val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY
val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0
// rounding to 8 digits
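Working the `3.94959677` docstring example through this formula (a standalone arithmetic check, not part of the patch; the inputs are 1997-02-28 10:30:00 vs 1996-10-30 00:00:00):

```java
public class MonthsBetweenDemo {
    public static void main(String[] args) {
        int months1 = 1997 * 12 + 1;       // Feb 1997, as a zero-based month count
        int months2 = 1996 * 12 + 9;       // Oct 1996
        double timesBetween = 10.5 / 24.0; // 10:30 of a day, as a fraction (0.4375)
        // diff = whole months + (day-of-month delta + time-of-day delta) / 31
        double diff = (months1 - months2) + (28 - 30 + timesBetween) / 31.0;
        // monthsBetween rounds to 8 digits:
        System.out.println(Math.round(diff * 1e8) / 1e8); // 3.94959677
    }
}
```

So the value is 4 whole months minus `1.5625 / 31` of a month, matching the expected output quoted above.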
@@ -896,7 +980,7 @@ object DateTimeUtils {
*/
def convertTz(ts: SQLTimestamp, fromZone: TimeZone, toZone: TimeZone): SQLTimestamp = {
// We always use local timezone to parse or format a timestamp
- val localZone = threadLocalLocalTimeZone.get()
+ val localZone = defaultTimeZone()
val utcTs = if (fromZone.getID == localZone.getID) {
ts
} else {
@@ -907,9 +991,9 @@ object DateTimeUtils {
if (toZone.getID == localZone.getID) {
utcTs
} else {
- val localTs2 = utcTs + toZone.getOffset(utcTs / 1000L) * 1000L // in toZone
+ val localTs = utcTs + toZone.getOffset(utcTs / 1000L) * 1000L // in toZone
// treat it as local timezone, convert to UTC (we could get the expected human time back)
- localTs2 - getOffsetFromLocalMillis(localTs2 / 1000L, localZone) * 1000L
+ localTs - getOffsetFromLocalMillis(localTs / 1000L, localZone) * 1000L
}
}
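For fixed-offset zones, `convertTz` reduces to applying the offset difference; this simplified Java sketch (illustration only; the real version routes through the local zone and `getOffsetFromLocalMillis` to handle DST) shows the effective wall-clock reinterpretation:

```java
import java.util.TimeZone;

public class ConvertTzDemo {
    static final long MICROS_PER_HOUR = 3600L * 1_000_000L;

    // Simplified convertTz for zones without DST: keep the instant's wall
    // clock in `from` and re-express it as the corresponding clock in `to`.
    static long convertTz(long micros, TimeZone from, TimeZone to) {
        long millis = micros / 1000L;
        return micros + (to.getOffset(millis) - from.getOffset(millis)) * 1000L;
    }

    public static void main(String[] args) {
        long noonUtc = 12L * MICROS_PER_HOUR; // 1970-01-01 12:00:00 UTC
        long shifted = convertTz(noonUtc, TimeZone.getTimeZone("GMT"), TimeZone.getTimeZone("GMT+05:00"));
        System.out.println(shifted / MICROS_PER_HOUR % 24); // 17: noon UTC shown as 17:00 at +05:00
    }
}
```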
@@ -934,7 +1018,6 @@ object DateTimeUtils {
*/
private[util] def resetThreadLocals(): Unit = {
threadLocalGmtCalendar.remove()
- threadLocalLocalTimeZone.remove()
threadLocalTimestampFormat.remove()
threadLocalDateFormat.remove()
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 817de48..81a97dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import java.util.TimeZone
+
import org.scalatest.ShouldMatchers
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier}
@@ -258,7 +260,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
val c = testRelation2.output(2)
val plan = testRelation2.select('c).orderBy(Floor('a).asc)
- val expected = testRelation2.select(c, a).orderBy(Floor(a.cast(DoubleType)).asc).select(c)
+ val expected = testRelation2.select(c, a)
+ .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c)
checkAnalysis(plan, expected)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
index 2a0205b..553b159 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.analysis
+import java.util.TimeZone
+
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -32,7 +34,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
lazy val unresolved_c = UnresolvedAttribute("c")
lazy val gid = 'spark_grouping_id.int.withNullability(false)
lazy val hive_gid = 'grouping__id.int.withNullability(false)
- lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType)
+ lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType, Option(TimeZone.getDefault().getID))
lazy val nulInt = Literal(null, IntegerType)
lazy val nulStr = Literal(null, StringType)
lazy val r1 = LocalRelation(a, b, c)
@@ -213,7 +215,8 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
val originalPlan = Filter(Grouping(unresolved_a) === 0,
GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
- val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType) === 0,
+ val expected = Project(Seq(a, b),
+ Filter(Cast(grouping_a, IntegerType, Option(TimeZone.getDefault().getID)) === 0,
Aggregate(Seq(a, b, gid),
Seq(a, b, gid),
Expand(
http://git-wip-us.apache.org/repos/asf/spark/blob/2969fb43/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index b748595..8eccadb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -18,12 +18,14 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
-import java.util.{Calendar, TimeZone}
+import java.util.{Calendar, Locale, TimeZone}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -32,10 +34,10 @@ import org.apache.spark.unsafe.types.UTF8String
*/
class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
- private def cast(v: Any, targetType: DataType): Cast = {
+ private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
v match {
- case lit: Expression => Cast(lit, targetType)
- case _ => Cast(Literal(v), targetType)
+ case lit: Expression => Cast(lit, targetType, timeZoneId)
+ case _ => Cast(Literal(v), targetType, timeZoneId)
}
}
@@ -45,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
private def checkNullCast(from: DataType, to: DataType): Unit = {
- checkEvaluation(Cast(Literal.create(null, from), to), null)
+ checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null)
}
test("null cast") {
@@ -107,108 +109,98 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast string to timestamp") {
- checkEvaluation(Cast(Literal("123"), TimestampType), null)
-
- var c = Calendar.getInstance()
- c.set(2015, 0, 1, 0, 0, 0)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- c = Calendar.getInstance()
- c.set(2015, 2, 1, 0, 0, 0)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- c = Calendar.getInstance()
- c.set(2015, 2, 18, 0, 0, 0)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03-18"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18 "), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18T"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance()
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03-18 12:03:17"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17Z"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18 12:03:17Z"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17-1:0"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17-01:00"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17+07:30"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 0)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17+7:3"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance()
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 123)
- checkEvaluation(Cast(Literal("2015-03-18 12:03:17.123"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 456)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17.456Z"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18 12:03:17.456Z"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 123)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-1:0"), TimestampType),
- new Timestamp(c.getTimeInMillis))
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-01:00"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 123)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+07:30"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
- c.set(2015, 2, 18, 12, 3, 17)
- c.set(Calendar.MILLISECOND, 123)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+7:3"), TimestampType),
- new Timestamp(c.getTimeInMillis))
-
- checkEvaluation(Cast(Literal("2015-03-18 123142"), TimestampType), null)
- checkEvaluation(Cast(Literal("2015-03-18T123123"), TimestampType), null)
- checkEvaluation(Cast(Literal("2015-03-18X"), TimestampType), null)
- checkEvaluation(Cast(Literal("2015/03/18"), TimestampType), null)
- checkEvaluation(Cast(Literal("2015.03.18"), TimestampType), null)
- checkEvaluation(Cast(Literal("20150318"), TimestampType), null)
- checkEvaluation(Cast(Literal("2015-031-8"), TimestampType), null)
- checkEvaluation(Cast(Literal("2015-03-18T12:03:17-0:70"), TimestampType), null)
+ for (tz <- ALL_TIMEZONES) {
+ def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = {
+ checkEvaluation(cast(Literal(str), TimestampType, Option(tz.getID)), expected)
+ }
+
+ checkCastStringToTimestamp("123", null)
+
+ var c = Calendar.getInstance(tz)
+ c.set(2015, 0, 1, 0, 0, 0)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015", new Timestamp(c.getTimeInMillis))
+ c = Calendar.getInstance(tz)
+ c.set(2015, 2, 1, 0, 0, 0)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03", new Timestamp(c.getTimeInMillis))
+ c = Calendar.getInstance(tz)
+ c.set(2015, 2, 18, 0, 0, 0)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03-18", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18 ", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18T", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(tz)
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03-18 12:03:17", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18T12:03:17", new Timestamp(c.getTimeInMillis))
+
+ // If the string value includes timezone string, it represents the timestamp string
+ // in the timezone regardless of the timeZoneId parameter.
+ c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03-18T12:03:17Z", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18 12:03:17Z", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03-18T12:03:17-1:0", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18T12:03:17-01:00", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03-18T12:03:17+07:30", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 0)
+ checkCastStringToTimestamp("2015-03-18T12:03:17+7:3", new Timestamp(c.getTimeInMillis))
+
+ // tests for the string including milliseconds.
+ c = Calendar.getInstance(tz)
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 123)
+ checkCastStringToTimestamp("2015-03-18 12:03:17.123", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18T12:03:17.123", new Timestamp(c.getTimeInMillis))
+
+ // If the string value includes timezone string, it represents the timestamp string
+ // in the timezone regardless of the timeZoneId parameter.
+ c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 456)
+ checkCastStringToTimestamp("2015-03-18T12:03:17.456Z", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18 12:03:17.456Z", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 123)
+ checkCastStringToTimestamp("2015-03-18T12:03:17.123-1:0", new Timestamp(c.getTimeInMillis))
+ checkCastStringToTimestamp("2015-03-18T12:03:17.123-01:00", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 123)
+ checkCastStringToTimestamp("2015-03-18T12:03:17.123+07:30", new Timestamp(c.getTimeInMillis))
+
+ c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
+ c.set(2015, 2, 18, 12, 3, 17)
+ c.set(Calendar.MILLISECOND, 123)
+ checkCastStringToTimestamp("2015-03-18T12:03:17.123+7:3", new Timestamp(c.getTimeInMillis))
+
+ checkCastStringToTimestamp("2015-03-18 123142", null)
+ checkCastStringToTimestamp("2015-03-18T123123", null)
+ checkCastStringToTimestamp("2015-03-18X", null)
+ checkCastStringToTimestamp("2015/03/18", null)
+ checkCastStringToTimestamp("2015.03.18", null)
+ checkCastStringToTimestamp("20150318", null)
+ checkCastStringToTimestamp("2015-031-8", null)
+ checkCastStringToTimestamp("2015-03-18T12:03:17-0:70", null)
+ }
}
test("cast from int") {
@@ -316,30 +308,43 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
val zts = sd + " 00:00:00"
val sts = sd + " 00:00:02"
val nts = sts + ".1"
- val ts = Timestamp.valueOf(nts)
-
- var c = Calendar.getInstance()
- c.set(2015, 2, 8, 2, 30, 0)
- checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType),
- c.getTimeInMillis * 1000)
- c = Calendar.getInstance()
- c.set(2015, 10, 1, 2, 30, 0)
- checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType),
- c.getTimeInMillis * 1000)
+ val ts = withDefaultTimeZone(TimeZoneGMT)(Timestamp.valueOf(nts))
+
+ for (tz <- ALL_TIMEZONES) {
+ val timeZoneId = Option(tz.getID)
+ var c = Calendar.getInstance(TimeZoneGMT)
+ c.set(2015, 2, 8, 2, 30, 0)
+ checkEvaluation(
+ cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
+ TimestampType, timeZoneId),
+ c.getTimeInMillis * 1000)
+ c = Calendar.getInstance(TimeZoneGMT)
+ c.set(2015, 10, 1, 2, 30, 0)
+ checkEvaluation(
+ cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId),
+ TimestampType, timeZoneId),
+ c.getTimeInMillis * 1000)
+ }
+
+ val gmtId = Option("GMT")
checkEvaluation(cast("abdef", StringType), "abdef")
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
- checkEvaluation(cast("abdef", TimestampType), null)
+ checkEvaluation(cast("abdef", TimestampType, gmtId), null)
checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
checkEvaluation(cast(cast(sd, DateType), StringType), sd)
checkEvaluation(cast(cast(d, StringType), DateType), 0)
- checkEvaluation(cast(cast(nts, TimestampType), StringType), nts)
- checkEvaluation(cast(cast(ts, StringType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts))
+ checkEvaluation(cast(cast(nts, TimestampType, gmtId), StringType, gmtId), nts)
+ checkEvaluation(
+ cast(cast(ts, StringType, gmtId), TimestampType, gmtId),
+ DateTimeUtils.fromJavaTimestamp(ts))
// all convert to string type to check
- checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd)
- checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts)
+ checkEvaluation(cast(cast(cast(nts, TimestampType, gmtId), DateType, gmtId), StringType), sd)
+ checkEvaluation(
+ cast(cast(cast(ts, DateType, gmtId), TimestampType, gmtId), StringType, gmtId),
+ zts)
checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef")
@@ -351,7 +356,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
5.toShort)
checkEvaluation(
- cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType),
+ cast(cast(cast(cast(cast(cast("5", TimestampType, gmtId), ByteType),
DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
null)
checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
@@ -466,7 +471,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
checkEvaluation(cast(d, DecimalType(10, 2)), null)
checkEvaluation(cast(d, StringType), "1970-01-01")
- checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
+
+ val gmtId = Option("GMT")
+ checkEvaluation(cast(cast(d, TimestampType, gmtId), StringType, gmtId), "1970-01-01 00:00:00")
}
test("cast from timestamp") {
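
The tests above assert two rules: a timestamp string carrying its own zone/offset denotes one fixed instant regardless of the `timeZoneId` parameter, while a string without one is interpreted in the session timezone. A minimal sketch of those two rules using plain `java.text`/`java.util` (not Spark's `DateTimeUtils`; the class name `SessionTzSketch` is hypothetical):

```java
import java.text.SimpleDateFormat;
import java.util.TimeZone;

public class SessionTzSketch {
    public static void main(String[] args) throws Exception {
        // Rule 1: a string with an explicit offset pins down one instant,
        // whatever timezone the parser (the "session") is configured with.
        SimpleDateFormat withOffset = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ssXXX");
        withOffset.setTimeZone(TimeZone.getTimeZone("GMT"));
        long asGmt = withOffset.parse("2015-03-18T12:03:17-01:00").getTime();
        withOffset.setTimeZone(TimeZone.getTimeZone("GMT+07:30"));
        long asPlus730 = withOffset.parse("2015-03-18T12:03:17-01:00").getTime();
        System.out.println(asGmt == asPlus730); // true: same instant either way

        // Rule 2: a string without an offset is interpreted in the parser's
        // timezone, so the resulting instants differ by the zone gap (7h30m).
        SimpleDateFormat plain = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        plain.setTimeZone(TimeZone.getTimeZone("GMT"));
        long inGmt = plain.parse("2015-03-18 12:03:17").getTime();
        plain.setTimeZone(TimeZone.getTimeZone("GMT+07:30"));
        long inPlus730 = plain.parse("2015-03-18 12:03:17").getTime();
        System.out.println(inGmt - inPlus730 == 7L * 3600000 + 30 * 60000); // true
    }
}
```

This mirrors why the patch threads `timeZoneId` through `cast`: only offset-free strings (the common case) are sensitive to it.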