This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d817c9a60f51 [SPARK-47775][SQL] Support remaining scalar types in the
variant spec
d817c9a60f51 is described below
commit d817c9a60f51ef8035c8d2b37a995976ae54aa47
Author: Chenhao Li <[email protected]>
AuthorDate: Wed Apr 10 22:51:17 2024 +0800
[SPARK-47775][SQL] Support remaining scalar types in the variant spec
### What changes were proposed in this pull request?
This PR adds support for the remaining scalar types defined in the variant
spec (DATE, TIMESTAMP, TIMESTAMP_NTZ, FLOAT, BINARY). The current `parse_json`
expression doesn't produce these types, but we need them when we support
casting a corresponding Spark type into the variant type.
### Why are the changes needed?
This PR can be considered a preparation for the cast-to-variant feature
and will make the later PR smaller.
### Does this PR introduce _any_ user-facing change?
Yes. Existing variant expressions can decode more variant scalar types.
### How was this patch tested?
Unit tests. We manually construct variant values with these new scalar
types and test the existing variant expressions on them.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45945 from chenhao-db/support_atomic_types.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/unsafe/types/VariantVal.java | 8 +-
.../org/apache/spark/types/variant/Variant.java | 69 ++++++++++++-
.../apache/spark/types/variant/VariantUtil.java | 71 ++++++++++++-
.../spark/sql/catalyst/expressions/Cast.scala | 7 +-
.../expressions/variant/variantExpressions.scala | 84 +++++++++++-----
.../spark/sql/catalyst/json/JacksonGenerator.scala | 2 +-
.../variant/VariantExpressionSuite.scala | 112 +++++++++++++++++++++
7 files changed, 314 insertions(+), 39 deletions(-)
diff --git
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
index 652c05daf344..a441bab4ac41 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
@@ -21,6 +21,8 @@ import org.apache.spark.unsafe.Platform;
import org.apache.spark.types.variant.Variant;
import java.io.Serializable;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
import java.util.Arrays;
/**
@@ -99,13 +101,17 @@ public class VariantVal implements Serializable {
'}';
}
+ public String toJson(ZoneId zoneId) {
+ return new Variant(value, metadata).toJson(zoneId);
+ }
+
/**
* @return A human-readable representation of the Variant value. It is
always a JSON string at
* this moment.
*/
@Override
public String toString() {
- return new Variant(value, metadata).toJson();
+ return toJson(ZoneOffset.UTC);
}
/**
diff --git
a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
index 8340aadd261f..4aeb2c6e1435 100644
--- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
+++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java
@@ -23,7 +23,16 @@ import com.fasterxml.jackson.core.JsonGenerator;
import java.io.CharArrayWriter;
import java.io.IOException;
import java.math.BigDecimal;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.ZoneId;
+import java.time.ZoneOffset;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeFormatterBuilder;
+import java.time.temporal.ChronoUnit;
import java.util.Arrays;
+import java.util.Base64;
+import java.util.Locale;
import static org.apache.spark.types.variant.VariantUtil.*;
@@ -89,6 +98,16 @@ public final class Variant {
return VariantUtil.getDecimal(value, pos);
}
+ // Get a float value from the variant.
+ public float getFloat() {
+ return VariantUtil.getFloat(value, pos);
+ }
+
+ // Get a binary value from the variant.
+ public byte[] getBinary() {
+ return VariantUtil.getBinary(value, pos);
+ }
+
// Get a string value from the variant.
public String getString() {
return VariantUtil.getString(value, pos);
@@ -188,9 +207,9 @@ public final class Variant {
// Stringify the variant in JSON format.
// Throw `MALFORMED_VARIANT` if the variant is malformed.
- public String toJson() {
+ public String toJson(ZoneId zoneId) {
StringBuilder sb = new StringBuilder();
- toJsonImpl(value, metadata, pos, sb);
+ toJsonImpl(value, metadata, pos, sb, zoneId);
return sb.toString();
}
@@ -208,7 +227,30 @@ public final class Variant {
}
}
- static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder
sb) {
+ // A simplified and more performant version of `sb.append(escapeJson(str))`.
It is used when we
+ // know `str` doesn't contain any special character that needs escaping.
+ static void appendQuoted(StringBuilder sb, String str) {
+ sb.append('"');
+ sb.append(str);
+ sb.append('"');
+ }
+
+ private static final DateTimeFormatter TIMESTAMP_NTZ_FORMATTER = new
DateTimeFormatterBuilder()
+ .append(DateTimeFormatter.ISO_LOCAL_DATE)
+ .appendLiteral(' ')
+ .append(DateTimeFormatter.ISO_LOCAL_TIME)
+ .toFormatter(Locale.US);
+
+ private static final DateTimeFormatter TIMESTAMP_FORMATTER = new
DateTimeFormatterBuilder()
+ .append(TIMESTAMP_NTZ_FORMATTER)
+ .appendOffset("+HH:MM", "+00:00")
+ .toFormatter(Locale.US);
+
+ private static Instant microsToInstant(long timestamp) {
+ return Instant.EPOCH.plus(timestamp, ChronoUnit.MICROS);
+ }
+
+ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder
sb, ZoneId zoneId) {
switch (VariantUtil.getType(value, pos)) {
case OBJECT:
handleObject(value, pos, (size, idSize, offsetSize, idStart,
offsetStart, dataStart) -> {
@@ -220,7 +262,7 @@ public final class Variant {
if (i != 0) sb.append(',');
sb.append(escapeJson(getMetadataKey(metadata, id)));
sb.append(':');
- toJsonImpl(value, metadata, elementPos, sb);
+ toJsonImpl(value, metadata, elementPos, sb, zoneId);
}
sb.append('}');
return null;
@@ -233,7 +275,7 @@ public final class Variant {
int offset = readUnsigned(value, offsetStart + offsetSize * i,
offsetSize);
int elementPos = dataStart + offset;
if (i != 0) sb.append(',');
- toJsonImpl(value, metadata, elementPos, sb);
+ toJsonImpl(value, metadata, elementPos, sb, zoneId);
}
sb.append(']');
return null;
@@ -257,6 +299,23 @@ public final class Variant {
case DECIMAL:
sb.append(VariantUtil.getDecimal(value, pos).toPlainString());
break;
+ case DATE:
+ appendQuoted(sb, LocalDate.ofEpochDay((int) VariantUtil.getLong(value,
pos)).toString());
+ break;
+ case TIMESTAMP:
+ appendQuoted(sb, TIMESTAMP_FORMATTER.format(
+ microsToInstant(VariantUtil.getLong(value, pos)).atZone(zoneId)));
+ break;
+ case TIMESTAMP_NTZ:
+ appendQuoted(sb, TIMESTAMP_NTZ_FORMATTER.format(
+ microsToInstant(VariantUtil.getLong(value,
pos)).atZone(ZoneOffset.UTC)));
+ break;
+ case FLOAT:
+ sb.append(VariantUtil.getFloat(value, pos));
+ break;
+ case BINARY:
+ appendQuoted(sb,
Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
+ break;
}
}
}
diff --git
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
index 1d579188ccdb..e4e9cc8b4cfa 100644
---
a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
+++
b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java
@@ -23,6 +23,7 @@ import scala.collection.immutable.Map$;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.util.Arrays;
/**
* This class defines constants related to the variant format and provides
functions for
@@ -101,6 +102,21 @@ public class VariantUtil {
public static final int DECIMAL8 = 9;
// 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed
integer.
public static final int DECIMAL16 = 10;
+ // Date value. Content is 4-byte little-endian signed integer that
represents the number of days
+ // from the Unix epoch.
+ public static final int DATE = 11;
+ // Timestamp value. Content is 8-byte little-endian signed integer that
represents the number of
+ // microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. It is
displayed to users in
+ // their local time zones and may be displayed differently depending on the
execution environment.
+ public static final int TIMESTAMP = 12;
+ // Timestamp_ntz value. It has the same content as `TIMESTAMP` but should
always be interpreted
+ // as if the local time zone is UTC.
+ public static final int TIMESTAMP_NTZ = 13;
+ // 4-byte IEEE float.
+ public static final int FLOAT = 14;
+ // Binary value. The content is (4-byte little-endian unsigned integer
representing the binary
+ // size) + (size bytes of binary content).
+ public static final int BINARY = 15;
// Long string value. The content is (4-byte little-endian unsigned integer
representing the
// string size) + (size bytes of string content).
public static final int LONG_STR = 16;
@@ -212,6 +228,11 @@ public class VariantUtil {
STRING,
DOUBLE,
DECIMAL,
+ DATE,
+ TIMESTAMP,
+ TIMESTAMP_NTZ,
+ FLOAT,
+ BINARY,
}
// Get the value type of variant value `value[pos...]`. It is only legal to
call `get*` if
@@ -247,6 +268,16 @@ public class VariantUtil {
case DECIMAL8:
case DECIMAL16:
return Type.DECIMAL;
+ case DATE:
+ return Type.DATE;
+ case TIMESTAMP:
+ return Type.TIMESTAMP;
+ case TIMESTAMP_NTZ:
+ return Type.TIMESTAMP_NTZ;
+ case FLOAT:
+ return Type.FLOAT;
+ case BINARY:
+ return Type.BINARY;
case LONG_STR:
return Type.STRING;
default:
@@ -283,9 +314,13 @@ public class VariantUtil {
case INT2:
return 3;
case INT4:
+ case DATE:
+ case FLOAT:
return 5;
case INT8:
case DOUBLE:
+ case TIMESTAMP:
+ case TIMESTAMP_NTZ:
return 9;
case DECIMAL4:
return 6;
@@ -293,6 +328,7 @@ public class VariantUtil {
return 10;
case DECIMAL16:
return 18;
+ case BINARY:
case LONG_STR:
return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
default:
@@ -318,23 +354,31 @@ public class VariantUtil {
}
// Get a long value from variant value `value[pos...]`.
+ // It is only legal to call it if `getType` returns one of
`Type.LONG/DATE/TIMESTAMP/
+ // TIMESTAMP_NTZ`. If the type is `DATE`, the return value is guaranteed to
fit into an int and
+ // represents the number of days from the Unix epoch. If the type is
`TIMESTAMP/TIMESTAMP_NTZ`,
+ // the return value represents the number of microseconds from the Unix
epoch.
// Throw `MALFORMED_VARIANT` if the variant is malformed.
public static long getLong(byte[] value, int pos) {
checkIndex(pos, value.length);
int basicType = value[pos] & BASIC_TYPE_MASK;
int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
- if (basicType != PRIMITIVE) throw unexpectedType(Type.LONG);
+ String exceptionMessage = "Expect type to be
LONG/DATE/TIMESTAMP/TIMESTAMP_NTZ";
+ if (basicType != PRIMITIVE) throw new
IllegalStateException(exceptionMessage);
switch (typeInfo) {
case INT1:
return readLong(value, pos + 1, 1);
case INT2:
return readLong(value, pos + 1, 2);
case INT4:
+ case DATE:
return readLong(value, pos + 1, 4);
case INT8:
+ case TIMESTAMP:
+ case TIMESTAMP_NTZ:
return readLong(value, pos + 1, 8);
default:
- throw unexpectedType(Type.LONG);
+ throw new IllegalStateException(exceptionMessage);
}
}
@@ -380,6 +424,29 @@ public class VariantUtil {
return result.stripTrailingZeros();
}
+ // Get a float value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static float getFloat(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != PRIMITIVE || typeInfo != FLOAT) throw
unexpectedType(Type.FLOAT);
+ return Float.intBitsToFloat((int) readLong(value, pos + 1, 4));
+ }
+
+ // Get a binary value from variant value `value[pos...]`.
+ // Throw `MALFORMED_VARIANT` if the variant is malformed.
+ public static byte[] getBinary(byte[] value, int pos) {
+ checkIndex(pos, value.length);
+ int basicType = value[pos] & BASIC_TYPE_MASK;
+ int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
+ if (basicType != PRIMITIVE || typeInfo != BINARY) throw
unexpectedType(Type.BINARY);
+ int start = pos + 1 + U32_SIZE;
+ int length = readUnsigned(value, pos + 1, U32_SIZE);
+ checkIndex(start + length - 1, value.length);
+ return Arrays.copyOfRange(value, start, start + length);
+ }
+
// Get a string value from variant value `value[pos...]`.
// Throw `MALFORMED_VARIANT` if the variant is malformed.
public static String getString(byte[] value, int pos) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8a077d9e9acb..94cf7130d485 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1114,7 +1114,7 @@ case class Cast(
_ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
} else if (from.isInstanceOf[VariantType]) {
buildCast[VariantVal](_, v => {
- variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId)
+ variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId,
zoneId)
})
} else {
to match {
@@ -1211,11 +1211,12 @@ case class Cast(
case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
val tmp = ctx.freshVariable("tmp", classOf[Object])
val dataTypeArg = ctx.addReferenceObj("dataType", to)
- val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
+ val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+ val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId,
classOf[ZoneId].getName)
val failOnError = evalMode != EvalMode.TRY
val cls = classOf[variant.VariantGet].getName
code"""
- Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneIdArg);
+ Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg,
$zoneIdArg);
if ($tmp == null) {
$evNull = true;
} else {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index c5e316dc6c8c..8b09bf5f7de0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions.variant
+import java.time.ZoneId
+
import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.SparkRuntimeException
@@ -170,7 +172,8 @@ case class VariantGet(
parsedPath,
dataType,
failOnError,
- timeZoneId)
+ timeZoneId,
+ zoneId)
}
protected override def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
@@ -178,14 +181,15 @@ case class VariantGet(
val tmp = ctx.freshVariable("tmp", classOf[Object])
val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
- val zoneIdArg = ctx.addReferenceObj("zoneId", timeZoneId)
+ val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
+ val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId,
classOf[ZoneId].getName)
val code = code"""
${childCode.code}
boolean ${ev.isNull} = ${childCode.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
Object $tmp =
org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
- ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError,
$zoneIdArg);
+ ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError,
$zoneStrArg, $zoneIdArg);
if ($tmp == null) {
${ev.isNull} = true;
} else {
@@ -228,7 +232,8 @@ case object VariantGet {
parsedPath: Array[VariantPathParser.PathSegment],
dataType: DataType,
failOnError: Boolean,
- zoneId: Option[String]): Any = {
+ zoneStr: Option[String],
+ zoneId: ZoneId): Any = {
var v = new Variant(input.getValue, input.getMetadata)
for (path <- parsedPath) {
v = path match {
@@ -238,7 +243,7 @@ case object VariantGet {
}
if (v == null) return null
}
- VariantGet.cast(v, dataType, failOnError, zoneId)
+ VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
}
/**
@@ -249,9 +254,10 @@ case object VariantGet {
input: VariantVal,
dataType: DataType,
failOnError: Boolean,
- zoneId: Option[String]): Any = {
+ zoneStr: Option[String],
+ zoneId: ZoneId): Any = {
val v = new Variant(input.getValue, input.getMetadata)
- VariantGet.cast(v, dataType, failOnError, zoneId)
+ VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
}
/**
@@ -261,9 +267,19 @@ case object VariantGet {
* "hello" to int). If the cast fails, throw an exception when `failOnError`
is true, or return a
* SQL NULL when it is false.
*/
- def cast(v: Variant, dataType: DataType, failOnError: Boolean, zoneId:
Option[String]): Any = {
- def invalidCast(): Any =
- if (failOnError) throw QueryExecutionErrors.invalidVariantCast(v.toJson,
dataType) else null
+ def cast(
+ v: Variant,
+ dataType: DataType,
+ failOnError: Boolean,
+ zoneStr: Option[String],
+ zoneId: ZoneId): Any = {
+ def invalidCast(): Any = {
+ if (failOnError) {
+ throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId),
dataType)
+ } else {
+ null
+ }
+ }
if (dataType == VariantType) return new VariantVal(v.getValue,
v.getMetadata)
val variantType = v.getType
@@ -273,15 +289,22 @@ case object VariantGet {
val input = variantType match {
case Type.OBJECT | Type.ARRAY =>
return if (dataType.isInstanceOf[StringType]) {
- UTF8String.fromString(v.toJson)
+ UTF8String.fromString(v.toJson(zoneId))
} else {
invalidCast()
}
- case Type.BOOLEAN => v.getBoolean
- case Type.LONG => v.getLong
- case Type.STRING => UTF8String.fromString(v.getString)
- case Type.DOUBLE => v.getDouble
- case Type.DECIMAL => Decimal(v.getDecimal)
+ case Type.BOOLEAN => Literal(v.getBoolean, BooleanType)
+ case Type.LONG => Literal(v.getLong, LongType)
+ case Type.STRING => Literal(UTF8String.fromString(v.getString),
StringType)
+ case Type.DOUBLE => Literal(v.getDouble, DoubleType)
+ case Type.DECIMAL =>
+ val d = Decimal(v.getDecimal)
+ Literal(Decimal(v.getDecimal), DecimalType(d.precision, d.scale))
+ case Type.DATE => Literal(v.getLong.toInt, DateType)
+ case Type.TIMESTAMP => Literal(v.getLong, TimestampType)
+ case Type.TIMESTAMP_NTZ => Literal(v.getLong, TimestampNTZType)
+ case Type.FLOAT => Literal(v.getFloat, FloatType)
+ case Type.BINARY => Literal(v.getBinary, BinaryType)
// We have handled other cases and should never reach here. This
case is only intended
// to by pass the compiler exhaustiveness check.
case _ => throw QueryExecutionErrors.unreachableError()
@@ -289,15 +312,17 @@ case object VariantGet {
// We mostly use the `Cast` expression to implement the cast. However,
`Cast` silently
// ignores the overflow in the long/decimal -> timestamp cast, and we
want to enforce
// strict overflow checks.
- input match {
- case l: Long if dataType == TimestampType =>
- try Math.multiplyExact(l, MICROS_PER_SECOND)
+ input.dataType match {
+ case LongType if dataType == TimestampType =>
+ try Math.multiplyExact(input.value.asInstanceOf[Long],
MICROS_PER_SECOND)
catch {
case _: ArithmeticException => invalidCast()
}
- case d: Decimal if dataType == TimestampType =>
+ case _: DecimalType if dataType == TimestampType =>
try {
- d.toJavaBigDecimal
+ input.value
+ .asInstanceOf[Decimal]
+ .toJavaBigDecimal
.multiply(new java.math.BigDecimal(MICROS_PER_SECOND))
.toBigInteger
.longValueExact()
@@ -305,9 +330,8 @@ case object VariantGet {
case _: ArithmeticException => invalidCast()
}
case _ =>
- val inputLiteral = Literal(input)
- if (Cast.canAnsiCast(inputLiteral.dataType, dataType)) {
- val result = Cast(inputLiteral, dataType, zoneId,
EvalMode.TRY).eval()
+ if (Cast.canAnsiCast(input.dataType, dataType)) {
+ val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval()
if (result == null) invalidCast() else result
} else {
invalidCast()
@@ -318,7 +342,7 @@ case object VariantGet {
val size = v.arraySize()
val array = new Array[Any](size)
for (i <- 0 until size) {
- array(i) = cast(v.getElementAtIndex(i), elementType, failOnError,
zoneId)
+ array(i) = cast(v.getElementAtIndex(i), elementType, failOnError,
zoneStr, zoneId)
}
new GenericArrayData(array)
} else {
@@ -332,7 +356,7 @@ case object VariantGet {
for (i <- 0 until size) {
val field = v.getFieldAtIndex(i)
keyArray(i) = UTF8String.fromString(field.key)
- valueArray(i) = cast(field.value, valueType, failOnError, zoneId)
+ valueArray(i) = cast(field.value, valueType, failOnError, zoneStr,
zoneId)
}
ArrayBasedMapData(keyArray, valueArray)
} else {
@@ -345,7 +369,8 @@ case object VariantGet {
val field = v.getFieldAtIndex(i)
st.getFieldIndex(field.key) match {
case Some(idx) =>
- row.update(idx, cast(field.value, fields(idx).dataType,
failOnError, zoneId))
+ row.update(idx,
+ cast(field.value, fields(idx).dataType, failOnError,
zoneStr, zoneId))
case _ =>
}
}
@@ -576,6 +601,11 @@ object SchemaOfVariant {
case Type.DECIMAL =>
val d = v.getDecimal
DecimalType(d.precision(), d.scale())
+ case Type.DATE => DateType
+ case Type.TIMESTAMP => TimestampType
+ case Type.TIMESTAMP_NTZ => TimestampNTZType
+ case Type.FLOAT => FloatType
+ case Type.BINARY => BinaryType
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index 1964b5f24b34..80f2b2a0070c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -317,7 +317,7 @@ class JacksonGenerator(
}
def write(v: VariantVal): Unit = {
- gen.writeRawValue(v.toString)
+ gen.writeRawValue(v.toJson(options.zoneId))
}
def writeLineEnding(): Unit = {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index a5863e80a26c..24675518646d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -685,4 +686,115 @@ class VariantExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
Array(null, null, 1)
)
}
+
+ test("atomic types that are not produced by parse_json") {
+ // Dictionary size is `0` for value 0. An empty dictionary contains one
offset `0` for the
+ // one-past-the-end position (i.e. the sum of all string lengths).
+ val emptyMetadata = Array[Byte](VERSION, 0, 0)
+
+ def checkToJson(value: Array[Byte], expected: String): Unit = {
+ val input = Literal(new VariantVal(value, emptyMetadata))
+ checkEvaluation(StructsToJson(Map.empty, input), expected)
+ }
+
+ def checkCast(value: Array[Byte], dataType: DataType, expected: Any): Unit
= {
+ val input = Literal(new VariantVal(value, emptyMetadata))
+ checkEvaluation(Cast(input, dataType, evalMode = EvalMode.ANSI),
expected)
+ }
+
+ checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, 0), "\"1970-01-01\"")
+ checkToJson(Array(primitiveHeader(DATE), -1, -1, -1, 127),
"\"+5881580-07-11\"")
+ checkToJson(Array(primitiveHeader(DATE), 0, 0, 0, -128),
"\"-5877641-06-23\"")
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+ checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 0L)
+ checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType,
MICROS_PER_DAY)
+ }
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+ checkCast(Array(primitiveHeader(DATE), 0, 0, 0, 0), TimestampType, 8 *
MICROS_PER_HOUR)
+ checkCast(Array(primitiveHeader(DATE), 1, 0, 0, 0), TimestampType,
+ MICROS_PER_DAY + 8 * MICROS_PER_HOUR)
+ }
+
+ def littleEndianLong(value: Long): Array[Byte] =
+ BigInt(value).toByteArray.reverse.padTo(8, 0.toByte)
+
+ val time1 = littleEndianLong(0)
+ // In America/Los_Angeles timezone, timestamp value `skippedTime` is
2011-03-13 03:00:00.
+ // The next second of 2011-03-13 01:59:59 jumps to 2011-03-13 03:00:00.
+ val skippedTime = 1300010400000000L
+ val time2 = littleEndianLong(skippedTime)
+ val time3 = littleEndianLong(skippedTime - 1)
+ val time4 = littleEndianLong(Long.MinValue)
+ val time5 = littleEndianLong(Long.MaxValue)
+ val time6 = littleEndianLong(-62198755200000000L)
+ val timestampHeader = Array(primitiveHeader(TIMESTAMP))
+ val timestampNtzHeader = Array(primitiveHeader(TIMESTAMP_NTZ))
+
+ for (timeZone <- Seq("UTC", "America/Los_Angeles")) {
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+ checkToJson(timestampNtzHeader ++ time1, "\"1970-01-01 00:00:00\"")
+ checkToJson(timestampNtzHeader ++ time2, "\"2011-03-13 10:00:00\"")
+ checkToJson(timestampNtzHeader ++ time3, "\"2011-03-13
09:59:59.999999\"")
+ checkToJson(timestampNtzHeader ++ time4, "\"-290308-12-21
19:59:05.224192\"")
+ checkToJson(timestampNtzHeader ++ time5, "\"+294247-01-10
04:00:54.775807\"")
+ checkToJson(timestampNtzHeader ++ time6, "\"-0001-01-01 00:00:00\"")
+
+ checkCast(timestampNtzHeader ++ time1, DateType, 0)
+ checkCast(timestampNtzHeader ++ time2, DateType, 15046)
+ checkCast(timestampNtzHeader ++ time3, DateType, 15046)
+ checkCast(timestampNtzHeader ++ time4, DateType, -106751992)
+ checkCast(timestampNtzHeader ++ time5, DateType, 106751991)
+ checkCast(timestampNtzHeader ++ time6, DateType, -719893)
+ }
+ }
+
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+ checkToJson(timestampHeader ++ time1, "\"1970-01-01 00:00:00+00:00\"")
+ checkToJson(timestampHeader ++ time2, "\"2011-03-13 10:00:00+00:00\"")
+ checkToJson(timestampHeader ++ time3, "\"2011-03-13
09:59:59.999999+00:00\"")
+ checkToJson(timestampHeader ++ time4, "\"-290308-12-21
19:59:05.224192+00:00\"")
+ checkToJson(timestampHeader ++ time5, "\"+294247-01-10
04:00:54.775807+00:00\"")
+ checkToJson(timestampHeader ++ time6, "\"-0001-01-01 00:00:00+00:00\"")
+
+ checkCast(timestampHeader ++ time1, DateType, 0)
+ checkCast(timestampHeader ++ time2, DateType, 15046)
+ checkCast(timestampHeader ++ time3, DateType, 15046)
+ checkCast(timestampHeader ++ time4, DateType, -106751992)
+ checkCast(timestampHeader ++ time5, DateType, 106751991)
+ checkCast(timestampHeader ++ time6, DateType, -719893)
+ }
+
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") {
+ checkToJson(timestampHeader ++ time1, "\"1969-12-31 16:00:00-08:00\"")
+ checkToJson(timestampHeader ++ time2, "\"2011-03-13 03:00:00-07:00\"")
+ checkToJson(timestampHeader ++ time3, "\"2011-03-13
01:59:59.999999-08:00\"")
+ checkToJson(timestampHeader ++ time4, "\"-290308-12-21
12:06:07.224192-07:52\"")
+ checkToJson(timestampHeader ++ time5, "\"+294247-01-09
20:00:54.775807-08:00\"")
+ checkToJson(timestampHeader ++ time6, "\"-0002-12-31 16:07:02-07:52\"")
+
+ checkCast(timestampHeader ++ time1, DateType, -1)
+ checkCast(timestampHeader ++ time2, DateType, 15046)
+ checkCast(timestampHeader ++ time3, DateType, 15046)
+ checkCast(timestampHeader ++ time4, DateType, -106751992)
+ checkCast(timestampHeader ++ time5, DateType, 106751990)
+ checkCast(timestampHeader ++ time6, DateType, -719894)
+ }
+
+ checkToJson(Array(primitiveHeader(FLOAT)) ++
+ BigInt(java.lang.Float.floatToIntBits(1.23F)).toByteArray.reverse,
"1.23")
+ checkToJson(Array(primitiveHeader(FLOAT)) ++
+ BigInt(java.lang.Float.floatToIntBits(-0.0F)).toByteArray.reverse,
"-0.0")
+ // Note: 1.23F.toDouble != 1.23.
+ checkCast(Array(primitiveHeader(FLOAT)) ++
+ BigInt(java.lang.Float.floatToIntBits(1.23F)).toByteArray.reverse,
DoubleType, 1.23F.toDouble)
+
+ checkToJson(Array(primitiveHeader(BINARY), 0, 0, 0, 0), "\"\"")
+ checkToJson(Array(primitiveHeader(BINARY), 1, 0, 0, 0, 1), "\"AQ==\"")
+ checkToJson(Array(primitiveHeader(BINARY), 2, 0, 0, 0, 1, 2), "\"AQI=\"")
+ checkToJson(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3),
"\"AQID\"")
+ checkCast(Array(primitiveHeader(BINARY), 3, 0, 0, 0, 1, 2, 3), StringType,
+ "\u0001\u0002\u0003")
+ checkCast(Array(primitiveHeader(BINARY), 5, 0, 0, 0, 72, 101, 108, 108,
111), StringType,
+ "Hello")
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]