This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ddf2da74f52 [SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders
ddf2da74f52 is described below
commit ddf2da74f527ee00af99fe3928015149f9477734
Author: Herman van Hovell <[email protected]>
AuthorDate: Tue Jan 17 10:52:28 2023 -0800
[SPARK-41993][SQL] Move RowEncoder to AgnosticEncoders
### What changes were proposed in this pull request?
This PR makes `RowEncoder` produce an `AgnosticEncoder`. The expression
generation for these encoders is moved to `ScalaReflection` (this will be moved
out in a subsequent PR).
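For example, here is a minimal sketch of the new two-step flow (the schema is
illustrative; `RowEncoder.encoderFor` and the `AgnosticEncoder`-based
`ExpressionEncoder.apply` overload are introduced by this patch):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(
  StructField("id", IntegerType, nullable = false) ::
  StructField("name", StringType) :: Nil)

// Step 1: RowEncoder now produces an implementation-agnostic description.
val agnostic: AgnosticEncoder[Row] = RowEncoder.encoderFor(schema)

// Step 2: ScalaReflection generates the serializer/deserializer expressions
// from that description when building the ExpressionEncoder.
val encoder: ExpressionEncoder[Row] = ExpressionEncoder(agnostic)
```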
The generated serializer and deserializer expressions will change slightly
for both schema-based and type-based encoders. These are not semantically different
from the old expressions. Concretely, the following changes have been introduced:
- There is more type validation in maps/arrays/seqs for type-based
encoders. This should be a positive change, since it prevents users from passing
wrong data through erasure hacks.
- Array/Seq serialization is a bit stricter. Previously it was
possible to pass in sequences/arrays with the wrong type and/or nullability
(see the sketch after this list).
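To illustrate both points, a minimal sketch of the lenient behavior for `Row`
data (the schema and values are made up):

```scala
import java.time.LocalDate

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{DateType, StructField, StructType}

val schema = StructType(StructField("day", DateType) :: Nil)
val toInternalRow = RowEncoder(schema, lenient = true).createSerializer()

// With lenient serialization both external date representations are accepted;
// the generated ValidateExternalType check allows either class at runtime.
toInternalRow(Row(java.sql.Date.valueOf("2023-01-17")))
toInternalRow(Row(LocalDate.of(2023, 1, 17)))

// A value of the wrong external type, e.g. Row("2023-01-17"), now fails the
// runtime type validation instead of slipping through via erasure.
```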
### Why are the changes needed?
For the Spark Connect Scala Client we also want to be able to use `Row`
based results.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
This is a refactoring, so it is mostly covered by existing tests. I have added
tests to the catalyst suites for the cases that triggered failures downstream
(typed arrays in `WrappedArray` & the `Seq[_]` change in Scala 2.13).
Closes #39627 from hvanhovell/SPARK-41993-v2.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/catalyst/JavaTypeInference.scala | 4 +-
.../spark/sql/catalyst/ScalaReflection.scala | 317 ++++++++++++------
.../spark/sql/catalyst/SerializerBuildHelper.scala | 25 +-
.../sql/catalyst/encoders/AgnosticEncoder.scala | 128 ++++++--
.../sql/catalyst/encoders/ExpressionEncoder.scala | 5 +-
.../spark/sql/catalyst/encoders/RowEncoder.scala | 354 ++++-----------------
.../sql/catalyst/expressions/objects/objects.scala | 87 +++--
.../spark/sql/catalyst/ScalaReflectionSuite.scala | 9 +-
.../catalyst/encoders/ExpressionEncoderSuite.scala | 2 +
.../sql/catalyst/encoders/RowEncoderSuite.scala | 24 ++
.../catalyst/expressions/CodeGenerationSuite.scala | 2 +-
.../expressions/ObjectExpressionsSuite.scala | 9 +-
12 files changed, 462 insertions(+), 504 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 827807055ce..81f363dda36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -423,10 +423,10 @@ object JavaTypeInference {
case c if c == classOf[java.time.Period] =>
createSerializerForJavaPeriod(inputObject)
case c if c == classOf[java.math.BigInteger] =>
- createSerializerForJavaBigInteger(inputObject)
+ createSerializerForBigInteger(inputObject)
case c if c == classOf[java.math.BigDecimal] =>
- createSerializerForJavaBigDecimal(inputObject)
+ createSerializerForBigDecimal(inputObject)
case c if c == classOf[java.lang.Boolean] =>
createSerializerForBoolean(inputObject)
case c if c == classOf[java.lang.Byte] =>
createSerializerForByte(inputObject)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index e02e42cea1a..42208cd1098 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
import javax.lang.model.SourceVersion
import scala.annotation.tailrec
+import scala.language.existentials
import scala.reflect.ClassTag
import scala.reflect.internal.Symbols
import scala.util.{Failure, Success}
@@ -27,12 +28,13 @@ import scala.util.{Failure, Success}
import org.apache.commons.lang3.reflect.ConstructorUtils
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
-import org.apache.spark.sql.catalyst.expressions.{Expression, _}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -82,12 +84,24 @@ object ScalaReflection extends ScalaReflection {
}
}
- // TODO this name is slightly misleading. This returns the input
- // data type we expect to see during serialization.
- private[catalyst] def dataTypeFor(enc: AgnosticEncoder[_]): DataType = {
+ /**
+ * Return the data type we expect to see when deserializing a value with encoder `enc`.
+ */
+ private[catalyst] def externalDataTypeFor(enc: AgnosticEncoder[_]): DataType = {
+ externalDataTypeFor(enc, lenientSerialization = false)
+ }
+
+ private[catalyst] def lenientExternalDataTypeFor(enc: AgnosticEncoder[_]): DataType =
+ externalDataTypeFor(enc, enc.lenientSerialization)
+
+ private def externalDataTypeFor(
+ enc: AgnosticEncoder[_],
+ lenientSerialization: Boolean): DataType = {
// DataType can be native.
if (isNativeEncoder(enc)) {
enc.dataType
+ } else if (lenientSerialization) {
+ ObjectType(classOf[java.lang.Object])
} else {
ObjectType(enc.clsTag.runtimeClass)
}
@@ -123,7 +137,7 @@ object ScalaReflection extends ScalaReflection {
case NullEncoder => true
case CalendarIntervalEncoder => true
case BinaryEncoder => true
- case SparkDecimalEncoder => true
+ case _: SparkDecimalEncoder => true
case _ => false
}
@@ -155,11 +169,19 @@ object ScalaReflection extends ScalaReflection {
val walkedTypePath = WalkedTypePath().recordRoot(enc.clsTag.runtimeClass.getName)
// Assumes we are deserializing the first column of a row.
val input = GetColumnByOrdinal(0, enc.dataType)
- val deserializer = deserializerFor(
- enc,
- upCastToExpectedType(input, enc.dataType, walkedTypePath),
- walkedTypePath)
- expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
+ enc match {
+ case RowEncoder(fields) =>
+ val children = fields.zipWithIndex.map { case (f, i) =>
+ deserializerFor(f.enc, GetStructField(input, i), walkedTypePath)
+ }
+ CreateExternalRow(children, enc.schema)
+ case _ =>
+ val deserializer = deserializerFor(
+ enc,
+ upCastToExpectedType(input, enc.dataType, walkedTypePath),
+ walkedTypePath)
+ expressionWithNullSafety(deserializer, enc.nullable, walkedTypePath)
+ }
}
/**
@@ -178,19 +200,7 @@ object ScalaReflection extends ScalaReflection {
walkedTypePath: WalkedTypePath): Expression = enc match {
case _ if isNativeEncoder(enc) =>
path
- case BoxedBooleanEncoder =>
- createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
- case BoxedByteEncoder =>
- createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
- case BoxedShortEncoder =>
- createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
- case BoxedIntEncoder =>
- createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
- case BoxedLongEncoder =>
- createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
- case BoxedFloatEncoder =>
- createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
- case BoxedDoubleEncoder =>
+ case _: BoxedLeafEncoder[_, _] =>
createDeserializerForTypesSupportValueOf(path, enc.clsTag.runtimeClass)
case JavaEnumEncoder(tag) =>
val toString = createDeserializerForString(path, returnNullable = false)
@@ -204,9 +214,9 @@ object ScalaReflection extends ScalaReflection {
returnNullable = false)
case StringEncoder =>
createDeserializerForString(path, returnNullable = false)
- case ScalaDecimalEncoder =>
+ case _: ScalaDecimalEncoder =>
createDeserializerForScalaBigDecimal(path, returnNullable = false)
- case JavaDecimalEncoder =>
+ case _: JavaDecimalEncoder =>
createDeserializerForJavaBigDecimal(path, returnNullable = false)
case ScalaBigIntEncoder =>
createDeserializerForScalaBigInt(path)
@@ -216,13 +226,13 @@ object ScalaReflection extends ScalaReflection {
createDeserializerForDuration(path)
case YearMonthIntervalEncoder =>
createDeserializerForPeriod(path)
- case DateEncoder =>
+ case _: DateEncoder =>
createDeserializerForSqlDate(path)
- case LocalDateEncoder =>
+ case _: LocalDateEncoder =>
createDeserializerForLocalDate(path)
- case TimestampEncoder =>
+ case _: TimestampEncoder =>
createDeserializerForSqlTimestamp(path)
- case InstantEncoder =>
+ case _: InstantEncoder =>
createDeserializerForInstant(path)
case LocalDateTimeEncoder =>
createDeserializerForLocalDateTime(path)
@@ -232,39 +242,29 @@ object ScalaReflection extends ScalaReflection {
case OptionEncoder(valueEnc) =>
val newTypePath = walkedTypePath.recordOption(valueEnc.clsTag.runtimeClass.getName)
val deserializer = deserializerFor(valueEnc, path, newTypePath)
- WrapOption(deserializer, dataTypeFor(valueEnc))
-
- case ArrayEncoder(elementEnc: AgnosticEncoder[_]) =>
- val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- elementEnc.dataType,
- nullable = elementEnc.nullable,
- newTypePath,
- deserializerFor(elementEnc, _, newTypePath))
- }
+ WrapOption(deserializer, externalDataTypeFor(valueEnc))
+
+ case ArrayEncoder(elementEnc: AgnosticEncoder[_], containsNull) =>
Invoke(
- UnresolvedMapObjects(mapFunction, path),
+ deserializeArray(
+ path,
+ elementEnc,
+ containsNull,
+ None,
+ walkedTypePath),
toArrayMethodName(elementEnc),
ObjectType(enc.clsTag.runtimeClass),
returnNullable = false)
- case IterableEncoder(clsTag, elementEnc) =>
- val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
- val mapFunction: Expression => Expression = element => {
- // upcast the array element to the data type the encoder expected.
- deserializerForWithNullSafetyAndUpcast(
- element,
- elementEnc.dataType,
- nullable = elementEnc.nullable,
- newTypePath,
- deserializerFor(elementEnc, _, newTypePath))
- }
- UnresolvedMapObjects(mapFunction, path, Some(clsTag.runtimeClass))
+ case IterableEncoder(clsTag, elementEnc, containsNull, _) =>
+ deserializeArray(
+ path,
+ elementEnc,
+ containsNull,
+ Option(clsTag.runtimeClass),
+ walkedTypePath)
- case MapEncoder(tag, keyEncoder, valueEncoder) =>
+ case MapEncoder(tag, keyEncoder, valueEncoder, _) =>
val newTypePath = walkedTypePath.recordMap(
keyEncoder.clsTag.runtimeClass.getName,
valueEncoder.clsTag.runtimeClass.getName)
@@ -298,6 +298,39 @@ object ScalaReflection extends ScalaReflection {
IsNull(path),
expressions.Literal.create(null, dt),
NewInstance(cls, arguments, dt, propagateNull = false))
+
+ case RowEncoder(fields) =>
+ val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ val newTypePath = walkedTypePath.recordField(
+ f.enc.clsTag.runtimeClass.getName,
+ f.name)
+ exprs.If(
+ Invoke(path, "isNullAt", BooleanType, exprs.Literal(i) :: Nil),
+ exprs.Literal.create(null, externalDataTypeFor(f.enc)),
+ deserializerFor(f.enc, GetStructField(path, i), newTypePath))
+ }
+ exprs.If(IsNull(path),
+ exprs.Literal.create(null, externalDataTypeFor(enc)),
+ CreateExternalRow(convertedFields, enc.schema))
+ }
+
+ private def deserializeArray(
+ path: Expression,
+ elementEnc: AgnosticEncoder[_],
+ containsNull: Boolean,
+ cls: Option[Class[_]],
+ walkedTypePath: WalkedTypePath): Expression = {
+ val newTypePath = walkedTypePath.recordArray(elementEnc.clsTag.runtimeClass.getName)
+ val mapFunction: Expression => Expression = element => {
+ // upcast the array element to the data type the encoder expects.
+ deserializerForWithNullSafetyAndUpcast(
+ element,
+ elementEnc.dataType,
+ nullable = containsNull,
+ newTypePath,
+ deserializerFor(elementEnc, _, newTypePath))
+ }
+ UnresolvedMapObjects(mapFunction, path, cls)
}
/**
@@ -306,7 +339,7 @@ object ScalaReflection extends ScalaReflection {
* input object is located at ordinal 0 of a row, i.e., `BoundReference(0, _)`.
*/
def serializerFor(enc: AgnosticEncoder[_]): Expression = {
- val input = BoundReference(0, dataTypeFor(enc), nullable = enc.nullable)
+ val input = BoundReference(0, lenientExternalDataTypeFor(enc), nullable = enc.nullable)
serializerFor(enc, input)
}
@@ -327,45 +360,52 @@ object ScalaReflection extends ScalaReflection {
case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
case StringEncoder => createSerializerForString(input)
- case ScalaDecimalEncoder => createSerializerForScalaBigDecimal(input)
- case JavaDecimalEncoder => createSerializerForJavaBigDecimal(input)
- case ScalaBigIntEncoder => createSerializerForScalaBigInt(input)
- case JavaBigIntEncoder => createSerializerForJavaBigInteger(input)
+ case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt)
+ case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input, dt)
+ case JavaDecimalEncoder(dt, true) => createSerializerForAnyDecimal(input, dt)
+ case ScalaBigIntEncoder => createSerializerForBigInteger(input)
+ case JavaBigIntEncoder => createSerializerForBigInteger(input)
case DayTimeIntervalEncoder => createSerializerForJavaDuration(input)
case YearMonthIntervalEncoder => createSerializerForJavaPeriod(input)
- case DateEncoder => createSerializerForSqlDate(input)
- case LocalDateEncoder => createSerializerForJavaLocalDate(input)
- case TimestampEncoder => createSerializerForSqlTimestamp(input)
- case InstantEncoder => createSerializerForJavaInstant(input)
+ case DateEncoder(true) | LocalDateEncoder(true) => createSerializerForAnyDate(input)
+ case DateEncoder(false) => createSerializerForSqlDate(input)
+ case LocalDateEncoder(false) => createSerializerForJavaLocalDate(input)
+ case TimestampEncoder(true) | InstantEncoder(true) => createSerializerForAnyTimestamp(input)
+ case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
+ case InstantEncoder(false) => createSerializerForJavaInstant(input)
case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
case UDTEncoder(udt, udtClass) =>
createSerializerForUserDefinedType(input, udt, udtClass)
case OptionEncoder(valueEnc) =>
- serializerFor(valueEnc, UnwrapOption(dataTypeFor(valueEnc), input))
+ serializerFor(valueEnc, UnwrapOption(externalDataTypeFor(valueEnc), input))
- case ArrayEncoder(elementEncoder) =>
- serializerForArray(isArray = true, elementEncoder, input)
+ case ArrayEncoder(elementEncoder, containsNull) =>
+ if (elementEncoder.isPrimitive) {
+ createSerializerForPrimitiveArray(input, elementEncoder.dataType)
+ } else {
+ serializerForArray(elementEncoder, containsNull, input, lenientSerialization = false)
+ }
- case IterableEncoder(ctag, elementEncoder) =>
+ case IterableEncoder(ctag, elementEncoder, containsNull, lenientSerialization) =>
val getter = if (classOf[scala.collection.Set[_]].isAssignableFrom(ctag.runtimeClass)) {
// There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
// Note that the property of `Set` is only kept when manipulating the data as domain object.
- Invoke(input, "toSeq", ObjectType(classOf[Seq[_]]))
+ Invoke(input, "toSeq", ObjectType(classOf[scala.collection.Seq[_]]))
} else {
input
}
- serializerForArray(isArray = false, elementEncoder, getter)
+ serializerForArray(elementEncoder, containsNull, getter, lenientSerialization)
- case MapEncoder(_, keyEncoder, valueEncoder) =>
+ case MapEncoder(_, keyEncoder, valueEncoder, valueContainsNull) =>
createSerializerForMap(
input,
MapElementInformation(
- dataTypeFor(keyEncoder),
- nullable = !keyEncoder.isPrimitive,
- serializerFor(keyEncoder, _)),
+ ObjectType(classOf[AnyRef]),
+ nullable = keyEncoder.nullable,
+ validateAndSerializeElement(keyEncoder, keyEncoder.nullable)),
MapElementInformation(
- dataTypeFor(valueEncoder),
- nullable = !valueEncoder.isPrimitive,
- serializerFor(valueEncoder, _))
+ ObjectType(classOf[AnyRef]),
+ nullable = valueContainsNull,
+ validateAndSerializeElement(valueEncoder, valueContainsNull))
)
case ProductEncoder(_, fields) =>
@@ -377,25 +417,94 @@ object ScalaReflection extends ScalaReflection {
val getter = Invoke(
KnownNotNull(input),
field.name,
- dataTypeFor(field.enc),
- returnNullable = field.enc.nullable)
+ externalDataTypeFor(field.enc),
+ returnNullable = field.nullable)
field.name -> serializerFor(field.enc, getter)
}
createSerializerForObject(input, serializedFields)
+
+ case RowEncoder(fields) =>
+ val serializedFields = fields.zipWithIndex.map { case (field, index) =>
+ val fieldValue = serializerFor(
+ field.enc,
+ ValidateExternalType(
+ GetExternalRowField(input, index, field.name),
+ field.enc.dataType,
+ lenientExternalDataTypeFor(field.enc)))
+
+ val convertedField = if (field.nullable) {
+ exprs.If(
+ Invoke(input, "isNullAt", BooleanType, exprs.Literal(index) ::
Nil),
+ // Because we strip UDTs, `field.dataType` can be different from
`fieldValue.dataType`.
+ // We should use `fieldValue.dataType` here.
+ exprs.Literal.create(null, fieldValue.dataType),
+ fieldValue
+ )
+ } else {
+ AssertNotNull(fieldValue)
+ }
+ field.name -> convertedField
+ }
+ createSerializerForObject(input, serializedFields)
}
private def serializerForArray(
- isArray: Boolean,
elementEnc: AgnosticEncoder[_],
- input: Expression): Expression = {
- dataTypeFor(elementEnc) match {
- case dt: ObjectType =>
- createSerializerForMapObjects(input, dt, serializerFor(elementEnc, _))
- case dt if isArray && elementEnc.isPrimitive =>
- createSerializerForPrimitiveArray(input, dt)
- case dt =>
- createSerializerForGenericArray(input, dt, elementEnc.nullable)
+ elementNullable: Boolean,
+ input: Expression,
+ lenientSerialization: Boolean): Expression = {
+ // Default serializer for Seq and generic Arrays. This does not work for primitive arrays.
+ val genericSerializer = createSerializerForMapObjects(
+ input,
+ ObjectType(classOf[AnyRef]),
+ validateAndSerializeElement(elementEnc, elementNullable))
+
+ // Check if it is possible the user can pass a primitive array. This is the only case when it
+ // is safe to directly convert to an array (for generic arrays and Seqs the type and the
+ // nullability can be violated). If the user has passed a primitive array we create a special
+ // code path to deal with these.
+ val primitiveEncoderOption = elementEnc match {
+ case _ if !lenientSerialization => None
+ case enc: PrimitiveLeafEncoder[_] => Option(enc)
+ case enc: BoxedLeafEncoder[_, _] => Option(enc.primitive)
+ case _ => None
}
+ primitiveEncoderOption match {
+ case Some(primitiveEncoder) =>
+ val primitiveArrayClass = primitiveEncoder.clsTag.wrap.runtimeClass
+ val check = Invoke(
+ targetObject = exprs.Literal.fromObject(primitiveArrayClass),
+ functionName = "isInstance",
+ BooleanType,
+ arguments = input :: Nil,
+ propagateNull = false,
+ returnNullable = false)
+ exprs.If(
+ check,
+ // TODO replace this with `createSerializerForPrimitiveArray` as
+ // soon as Cast support ObjectType casts.
+ StaticInvoke(
+ classOf[ArrayData],
+ ArrayType(elementEnc.dataType, containsNull = false),
+ "toArrayData",
+ input :: Nil,
+ propagateNull = false,
+ returnNullable = false),
+ genericSerializer)
+ case None =>
+ genericSerializer
+ }
+ }
+
+ private def validateAndSerializeElement(
+ enc: AgnosticEncoder[_],
+ nullable: Boolean): Expression => Expression = { input =>
+ expressionWithNullSafety(
+ serializerFor(
+ enc,
+ ValidateExternalType(input, enc.dataType, lenientExternalDataTypeFor(enc))),
+ nullable,
+ WalkedTypePath())
}
/**
@@ -598,8 +707,8 @@ object ScalaReflection extends ScalaReflection {
case StringType => classOf[UTF8String]
case CalendarIntervalType => classOf[CalendarInterval]
case _: StructType => classOf[InternalRow]
- case _: ArrayType => classOf[ArrayType]
- case _: MapType => classOf[MapType]
+ case _: ArrayType => classOf[ArrayData]
+ case _: MapType => classOf[MapData]
case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
case ObjectType(cls) => cls
case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
@@ -657,7 +766,11 @@ object ScalaReflection extends ScalaReflection {
case NoSymbol => fallbackClass
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
- IterableEncoder(ClassTag(targetClass), encoder)
+ IterableEncoder(
+ ClassTag(targetClass),
+ encoder,
+ encoder.nullable,
+ lenientSerialization = false)
}
baseType(tpe) match {
@@ -698,18 +811,18 @@ object ScalaReflection extends ScalaReflection {
// Leaf encoders
case t if isSubtype(t, localTypeOf[String]) => StringEncoder
- case t if isSubtype(t, localTypeOf[Decimal]) => SparkDecimalEncoder
- case t if isSubtype(t, localTypeOf[BigDecimal]) => ScalaDecimalEncoder
- case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => JavaDecimalEncoder
+ case t if isSubtype(t, localTypeOf[Decimal]) => DEFAULT_SPARK_DECIMAL_ENCODER
+ case t if isSubtype(t, localTypeOf[BigDecimal]) => DEFAULT_SCALA_DECIMAL_ENCODER
+ case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => DEFAULT_JAVA_DECIMAL_ENCODER
case t if isSubtype(t, localTypeOf[BigInt]) => ScalaBigIntEncoder
case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => JavaBigIntEncoder
case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalEncoder
case t if isSubtype(t, localTypeOf[java.time.Duration]) => DayTimeIntervalEncoder
case t if isSubtype(t, localTypeOf[java.time.Period]) => YearMonthIntervalEncoder
- case t if isSubtype(t, localTypeOf[java.sql.Date]) => DateEncoder
- case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => LocalDateEncoder
- case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => TimestampEncoder
- case t if isSubtype(t, localTypeOf[java.time.Instant]) => InstantEncoder
+ case t if isSubtype(t, localTypeOf[java.sql.Date]) => STRICT_DATE_ENCODER
+ case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => STRICT_LOCAL_DATE_ENCODER
+ case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
+ case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
// UDT encoders
@@ -739,7 +852,7 @@ object ScalaReflection extends ScalaReflection {
elementType,
seenTypeSet,
path.recordArray(getClassNameFromType(elementType)))
- ArrayEncoder(encoder)
+ ArrayEncoder(encoder, encoder.nullable)
case t if isSubtype(t, localTypeOf[scala.collection.Seq[_]]) =>
createIterableEncoder(t, classOf[scala.collection.Seq[_]])
@@ -757,7 +870,7 @@ object ScalaReflection extends ScalaReflection {
valueType,
seenTypeSet,
path.recordValueForMap(getClassNameFromType(valueType)))
- MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder)
+ MapEncoder(ClassTag(getClassFromType(t)), keyEncoder, valueEncoder, valueEncoder.nullable)
case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
@@ -775,7 +888,7 @@ object ScalaReflection extends ScalaReflection {
fieldType,
seenTypeSet + t,
path.recordField(getClassNameFromType(fieldType), fieldName))
- EncoderField(fieldName, encoder)
+ EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
}
ProductEncoder(ClassTag(getClassFromType(t)), params)
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 25f6ce520d9..33b0edb0c44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -158,20 +158,29 @@ object SerializerBuildHelper {
returnNullable = false)
}
- def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
+ def createSerializerForBigDecimal(inputObject: Expression): Expression = {
+ createSerializerForBigDecimal(inputObject, DecimalType.SYSTEM_DEFAULT)
+ }
+
+ def createSerializerForBigDecimal(inputObject: Expression, dt: DecimalType): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
+ dt,
"apply",
inputObject :: Nil,
- returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow)
+ returnNullable = false), dt, nullOnOverflow)
}
- def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
- createSerializerForJavaBigDecimal(inputObject)
+ def createSerializerForAnyDecimal(inputObject: Expression, dt: DecimalType): Expression = {
+ CheckOverflow(StaticInvoke(
+ Decimal.getClass,
+ dt,
+ "fromDecimal",
+ inputObject :: Nil,
+ returnNullable = false), dt, nullOnOverflow)
}
- def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
+ def createSerializerForBigInteger(inputObject: Expression): Expression = {
CheckOverflow(StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
@@ -180,10 +189,6 @@ object SerializerBuildHelper {
returnNullable = false), DecimalType.BigIntDecimal, nullOnOverflow)
}
- def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
- createSerializerForJavaBigInteger(inputObject)
- }
-
def createSerializerForPrimitiveArray(
inputObject: Expression,
dataType: DataType): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 6081ac8dc28..cdc64f2ddb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -16,28 +16,33 @@
*/
package org.apache.spark.sql.catalyst.encoders
+import java.{sql => jsql}
import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInt}
import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
import scala.reflect.{classTag, ClassTag}
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
/**
* A non implementation specific encoder. This encoder containers all the information needed
* to generate an implementation specific encoder (e.g. InternalRow <=> Custom Object).
+ *
+ * The input of the serialization does not need to match the external type of the encoder. This is
+ * called lenient serialization. An example of this is lenient date serialization, in this case both
+ * [[java.sql.Date]] and [[java.time.LocalDate]] are allowed. Deserialization is never lenient; it
+ * will always produce instance of the external type.
*/
trait AgnosticEncoder[T] extends Encoder[T] {
def isPrimitive: Boolean
def nullable: Boolean = !isPrimitive
def dataType: DataType
override def schema: StructType = StructType(StructField("value", dataType, nullable) :: Nil)
+ def lenientSerialization: Boolean = false
}
-// TODO check RowEncoder
-// TODO check BeanEncoder
object AgnosticEncoders {
case class OptionEncoder[E](elementEncoder: AgnosticEncoder[E])
extends AgnosticEncoder[Option[E]] {
@@ -46,35 +51,48 @@ object AgnosticEncoders {
override val clsTag: ClassTag[Option[E]] = ClassTag(classOf[Option[E]])
}
- case class ArrayEncoder[E](element: AgnosticEncoder[E])
+ case class ArrayEncoder[E](element: AgnosticEncoder[E], containsNull: Boolean)
extends AgnosticEncoder[Array[E]] {
override def isPrimitive: Boolean = false
- override def dataType: DataType = ArrayType(element.dataType, element.nullable)
+ override def dataType: DataType = ArrayType(element.dataType, containsNull)
override val clsTag: ClassTag[Array[E]] = element.clsTag.wrap
}
- case class IterableEncoder[C <: Iterable[E], E](
+ /**
+ * Encoder for collections.
+ *
+ * This encoder can be lenient for [[Row]] encoders. In that case we allow [[Seq]], primitive
+ * array (if any), and generic arrays as input.
+ */
+ case class IterableEncoder[C, E](
override val clsTag: ClassTag[C],
- element: AgnosticEncoder[E])
+ element: AgnosticEncoder[E],
+ containsNull: Boolean,
+ override val lenientSerialization: Boolean)
extends AgnosticEncoder[C] {
override def isPrimitive: Boolean = false
- override val dataType: DataType = ArrayType(element.dataType, element.nullable)
+ override val dataType: DataType = ArrayType(element.dataType, containsNull)
}
case class MapEncoder[C, K, V](
override val clsTag: ClassTag[C],
keyEncoder: AgnosticEncoder[K],
- valueEncoder: AgnosticEncoder[V])
+ valueEncoder: AgnosticEncoder[V],
+ valueContainsNull: Boolean)
extends AgnosticEncoder[C] {
override def isPrimitive: Boolean = false
override val dataType: DataType = MapType(
keyEncoder.dataType,
valueEncoder.dataType,
- valueEncoder.nullable)
+ valueContainsNull)
}
- case class EncoderField(name: String, enc: AgnosticEncoder[_]) {
- def structField: StructField = StructField(name, enc.dataType, enc.nullable)
+ case class EncoderField(
+ name: String,
+ enc: AgnosticEncoder[_],
+ nullable: Boolean,
+ metadata: Metadata) {
+ def structField: StructField = StructField(name, enc.dataType, nullable, metadata)
}
// This supports both Product and DefinedByConstructorParams
@@ -87,6 +105,13 @@ object AgnosticEncoders {
override def dataType: DataType = schema
}
+ case class RowEncoder(fields: Seq[EncoderField]) extends AgnosticEncoder[Row] {
+ override def isPrimitive: Boolean = false
+ override val schema: StructType = StructType(fields.map(_.structField))
+ override def dataType: DataType = schema
+ override def clsTag: ClassTag[Row] = classTag[Row]
+ }
+
// This will only work for encoding from/to Sparks' InternalRow format.
// It is here for compatibility.
case class UDTEncoder[E >: Null](
@@ -116,39 +141,74 @@ object AgnosticEncoders {
}
// Primitive encoders
- case object PrimitiveBooleanEncoder extends LeafEncoder[Boolean](BooleanType)
- case object PrimitiveByteEncoder extends LeafEncoder[Byte](ByteType)
- case object PrimitiveShortEncoder extends LeafEncoder[Short](ShortType)
- case object PrimitiveIntEncoder extends LeafEncoder[Int](IntegerType)
- case object PrimitiveLongEncoder extends LeafEncoder[Long](LongType)
- case object PrimitiveFloatEncoder extends LeafEncoder[Float](FloatType)
- case object PrimitiveDoubleEncoder extends LeafEncoder[Double](DoubleType)
+ abstract class PrimitiveLeafEncoder[E : ClassTag](dataType: DataType)
+ extends LeafEncoder[E](dataType)
+ case object PrimitiveBooleanEncoder extends PrimitiveLeafEncoder[Boolean](BooleanType)
+ case object PrimitiveByteEncoder extends PrimitiveLeafEncoder[Byte](ByteType)
+ case object PrimitiveShortEncoder extends PrimitiveLeafEncoder[Short](ShortType)
+ case object PrimitiveIntEncoder extends PrimitiveLeafEncoder[Int](IntegerType)
+ case object PrimitiveLongEncoder extends PrimitiveLeafEncoder[Long](LongType)
+ case object PrimitiveFloatEncoder extends PrimitiveLeafEncoder[Float](FloatType)
+ case object PrimitiveDoubleEncoder extends PrimitiveLeafEncoder[Double](DoubleType)
// Primitive wrapper encoders.
- case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
- case object BoxedBooleanEncoder extends LeafEncoder[java.lang.Boolean](BooleanType)
- case object BoxedByteEncoder extends LeafEncoder[java.lang.Byte](ByteType)
- case object BoxedShortEncoder extends LeafEncoder[java.lang.Short](ShortType)
- case object BoxedIntEncoder extends LeafEncoder[java.lang.Integer](IntegerType)
- case object BoxedLongEncoder extends LeafEncoder[java.lang.Long](LongType)
- case object BoxedFloatEncoder extends LeafEncoder[java.lang.Float](FloatType)
- case object BoxedDoubleEncoder extends LeafEncoder[java.lang.Double](DoubleType)
+ abstract class BoxedLeafEncoder[E : ClassTag, P](
+ dataType: DataType,
+ val primitive: PrimitiveLeafEncoder[P])
+ extends LeafEncoder[E](dataType)
+ case object BoxedBooleanEncoder
+ extends BoxedLeafEncoder[java.lang.Boolean, Boolean](BooleanType, PrimitiveBooleanEncoder)
+ case object BoxedByteEncoder
+ extends BoxedLeafEncoder[java.lang.Byte, Byte](ByteType, PrimitiveByteEncoder)
+ case object BoxedShortEncoder
+ extends BoxedLeafEncoder[java.lang.Short, Short](ShortType, PrimitiveShortEncoder)
+ case object BoxedIntEncoder
+ extends BoxedLeafEncoder[java.lang.Integer, Int](IntegerType, PrimitiveIntEncoder)
+ case object BoxedLongEncoder
+ extends BoxedLeafEncoder[java.lang.Long, Long](LongType, PrimitiveLongEncoder)
+ case object BoxedFloatEncoder
+ extends BoxedLeafEncoder[java.lang.Float, Float](FloatType, PrimitiveFloatEncoder)
+ case object BoxedDoubleEncoder
+ extends BoxedLeafEncoder[java.lang.Double, Double](DoubleType, PrimitiveDoubleEncoder)
// Nullable leaf encoders
+ case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
case object StringEncoder extends LeafEncoder[String](StringType)
case object BinaryEncoder extends LeafEncoder[Array[Byte]](BinaryType)
- case object SparkDecimalEncoder extends LeafEncoder[Decimal](DecimalType.SYSTEM_DEFAULT)
- case object ScalaDecimalEncoder extends LeafEncoder[BigDecimal](DecimalType.SYSTEM_DEFAULT)
- case object JavaDecimalEncoder extends LeafEncoder[JBigDecimal](DecimalType.SYSTEM_DEFAULT)
case object ScalaBigIntEncoder extends LeafEncoder[BigInt](DecimalType.BigIntDecimal)
case object JavaBigIntEncoder extends LeafEncoder[JBigInt](DecimalType.BigIntDecimal)
case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType)
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
- case object DateEncoder extends LeafEncoder[java.sql.Date](DateType)
- case object LocalDateEncoder extends LeafEncoder[LocalDate](DateType)
- case object TimestampEncoder extends LeafEncoder[java.sql.Timestamp](TimestampType)
- case object InstantEncoder extends LeafEncoder[Instant](TimestampType)
+ case class DateEncoder(override val lenientSerialization: Boolean)
+ extends LeafEncoder[jsql.Date](DateType)
+ case class LocalDateEncoder(override val lenientSerialization: Boolean)
+ extends LeafEncoder[LocalDate](DateType)
+ case class TimestampEncoder(override val lenientSerialization: Boolean)
+ extends LeafEncoder[jsql.Timestamp](TimestampType)
+ case class InstantEncoder(override val lenientSerialization: Boolean)
+ extends LeafEncoder[Instant](TimestampType)
case object LocalDateTimeEncoder extends LeafEncoder[LocalDateTime](TimestampNTZType)
+
+ case class SparkDecimalEncoder(dt: DecimalType) extends LeafEncoder[Decimal](dt)
+ case class ScalaDecimalEncoder(dt: DecimalType) extends LeafEncoder[BigDecimal](dt)
+ case class JavaDecimalEncoder(dt: DecimalType, override val lenientSerialization: Boolean)
+ extends LeafEncoder[JBigDecimal](dt)
+
+ val STRICT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = false)
+ val STRICT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = false)
+ val STRICT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = false)
+ val STRICT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = false)
+ val LENIENT_DATE_ENCODER: DateEncoder = DateEncoder(lenientSerialization = true)
+ val LENIENT_LOCAL_DATE_ENCODER: LocalDateEncoder = LocalDateEncoder(lenientSerialization = true)
+ val LENIENT_TIMESTAMP_ENCODER: TimestampEncoder = TimestampEncoder(lenientSerialization = true)
+ val LENIENT_INSTANT_ENCODER: InstantEncoder = InstantEncoder(lenientSerialization = true)
+
+ val DEFAULT_SPARK_DECIMAL_ENCODER: SparkDecimalEncoder =
+ SparkDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
+ val DEFAULT_SCALA_DECIMAL_ENCODER: ScalaDecimalEncoder =
+ ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
+ val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
+ JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 82a6863b5ff..9ca2fc72ad9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -47,7 +47,10 @@ import org.apache.spark.util.Utils
object ExpressionEncoder {
def apply[T : TypeTag](): ExpressionEncoder[T] = {
- val enc = ScalaReflection.encoderFor[T]
+ apply(ScalaReflection.encoderFor[T])
+ }
+
+ def apply[T](enc: AgnosticEncoder[T]): ExpressionEncoder[T] = {
new ExpressionEncoder[T](
ScalaReflection.serializerFor(enc),
ScalaReflection.deserializerFor(enc),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 8eb3475acaa..78243894544 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -17,19 +17,11 @@
package org.apache.spark.sql.catalyst.encoders
-import scala.annotation.tailrec
-import scala.collection.Map
-import scala.reflect.ClassTag
+import scala.collection.mutable
+import scala.reflect.classTag
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{ScalaReflection, WalkedTypePath}
-import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
-import org.apache.spark.sql.catalyst.SerializerBuildHelper._
-import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMont [...]
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -68,224 +60,46 @@ import org.apache.spark.sql.types._
*/
object RowEncoder {
def apply(schema: StructType, lenient: Boolean): ExpressionEncoder[Row] = {
- val cls = classOf[Row]
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val serializer = serializerFor(inputObject, schema, lenient)
- val deserializer = deserializerFor(GetColumnByOrdinal(0, serializer.dataType), schema)
- new ExpressionEncoder[Row](
- serializer,
- deserializer,
- ClassTag(cls))
+ ExpressionEncoder(encoderFor(schema, lenient))
}
+
def apply(schema: StructType): ExpressionEncoder[Row] = {
apply(schema, lenient = false)
}
- private def serializerFor(
- inputObject: Expression,
- inputType: DataType,
- lenient: Boolean): Expression = inputType match {
- case dt if ScalaReflection.isNativeType(dt) => inputObject
-
- case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType, lenient)
-
- case udt: UserDefinedType[_] =>
- val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
- val udtClass: Class[_] = if (annotation != null) {
- annotation.udt()
- } else {
- UDTRegistration.getUDTFor(udt.userClass.getName).getOrElse {
- throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
- }
- }
- val obj = NewInstance(
- udtClass,
- Nil,
- dataType = ObjectType(udtClass), false)
- Invoke(obj, "serialize", udt, inputObject :: Nil, returnNullable = false)
-
- case TimestampType =>
- if (lenient) {
- createSerializerForAnyTimestamp(inputObject)
- } else if (SQLConf.get.datetimeJava8ApiEnabled) {
- createSerializerForJavaInstant(inputObject)
- } else {
- createSerializerForSqlTimestamp(inputObject)
- }
-
- case TimestampNTZType => createSerializerForLocalDateTime(inputObject)
-
- case DateType =>
- if (lenient) {
- createSerializerForAnyDate(inputObject)
- } else if (SQLConf.get.datetimeJava8ApiEnabled) {
- createSerializerForJavaLocalDate(inputObject)
- } else {
- createSerializerForSqlDate(inputObject)
- }
-
- case _: DayTimeIntervalType => createSerializerForJavaDuration(inputObject)
-
- case _: YearMonthIntervalType => createSerializerForJavaPeriod(inputObject)
-
- case d: DecimalType =>
- CheckOverflow(StaticInvoke(
- Decimal.getClass,
- d,
- "fromDecimal",
- inputObject :: Nil,
- returnNullable = false), d, !SQLConf.get.ansiEnabled)
-
- case StringType => createSerializerForString(inputObject)
-
- case t @ ArrayType(et, containsNull) =>
- et match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
- StaticInvoke(
- classOf[ArrayData],
- t,
- "toArrayData",
- inputObject :: Nil,
- returnNullable = false)
-
- case _ =>
- createSerializerForMapObjects(
- inputObject,
- ObjectType(classOf[Object]),
- element => {
- val value = serializerFor(ValidateExternalType(element, et, lenient), et, lenient)
- expressionWithNullSafety(value, containsNull, WalkedTypePath())
- })
- }
-
- case t @ MapType(kt, vt, valueNullable) =>
- val keys =
- Invoke(
- Invoke(inputObject, "keysIterator",
ObjectType(classOf[scala.collection.Iterator[_]]),
- returnNullable = false),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
- val convertedKeys = serializerFor(keys, ArrayType(kt, false), lenient)
-
- val values =
- Invoke(
- Invoke(inputObject, "valuesIterator",
ObjectType(classOf[scala.collection.Iterator[_]]),
- returnNullable = false),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]), returnNullable = false)
- val convertedValues = serializerFor(values, ArrayType(vt, valueNullable), lenient)
-
- val nonNullOutput = NewInstance(
- classOf[ArrayBasedMapData],
- convertedKeys :: convertedValues :: Nil,
- dataType = t,
- propagateNull = false)
-
- if (inputObject.nullable) {
- expressionForNullableExpr(inputObject, nonNullOutput)
- } else {
- nonNullOutput
- }
-
- case StructType(fields) =>
- val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
- val fieldValue = serializerFor(
- ValidateExternalType(
- GetExternalRowField(inputObject, index, field.name),
- field.dataType,
- lenient),
- field.dataType,
- lenient)
- val convertedField = if (field.nullable) {
- If(
- Invoke(inputObject, "isNullAt", BooleanType, Literal(index) ::
Nil),
- // Because we strip UDTs, `field.dataType` can be different from
`fieldValue.dataType`.
- // We should use `fieldValue.dataType` here.
- Literal.create(null, fieldValue.dataType),
- fieldValue
- )
- } else {
- fieldValue
- }
- Literal(field.name) :: convertedField :: Nil
- })
-
- if (inputObject.nullable) {
- expressionForNullableExpr(inputObject, nonNullOutput)
- } else {
- nonNullOutput
- }
- // For other data types, return the internal catalyst value as it is.
- case _ => inputObject
- }
-
- /**
- * Returns the `DataType` that can be used when generating code that converts input data
- * into the Spark SQL internal format. Unlike `externalDataTypeFor`, the `DataType` returned
- * by this function can be more permissive since multiple external types may map to a single
- * internal type. For example, for an input with DecimalType in external row, its external types
- * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or
- * `org.apache.spark.sql.types.Decimal`.
- */
- def externalDataTypeForInput(dt: DataType, lenient: Boolean): DataType = dt match {
- // In order to support both Decimal and java/scala BigDecimal in external row, we make this
- // as java.lang.Object.
- case _: DecimalType => ObjectType(classOf[java.lang.Object])
- // In order to support both Array and Seq in external row, we make this as java.lang.Object.
- case _: ArrayType => ObjectType(classOf[java.lang.Object])
- case _: DateType | _: TimestampType if lenient => ObjectType(classOf[java.lang.Object])
- case _ => externalDataTypeFor(dt)
- }
-
- @tailrec
- def externalDataTypeFor(dt: DataType): DataType = dt match {
- case _ if ScalaReflection.isNativeType(dt) => dt
- case TimestampType =>
- if (SQLConf.get.datetimeJava8ApiEnabled) {
- ObjectType(classOf[java.time.Instant])
- } else {
- ObjectType(classOf[java.sql.Timestamp])
- }
- case TimestampNTZType =>
- ObjectType(classOf[java.time.LocalDateTime])
- case DateType =>
- if (SQLConf.get.datetimeJava8ApiEnabled) {
- ObjectType(classOf[java.time.LocalDate])
- } else {
- ObjectType(classOf[java.sql.Date])
- }
- case _: DayTimeIntervalType => ObjectType(classOf[java.time.Duration])
- case _: YearMonthIntervalType => ObjectType(classOf[java.time.Period])
- case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType)
- case udt: UserDefinedType[_] => ObjectType(udt.userClass)
- case _ => dt.physicalDataType match {
- case _: PhysicalArrayType => ObjectType(classOf[scala.collection.Seq[_]])
- case _: PhysicalDecimalType => ObjectType(classOf[java.math.BigDecimal])
- case _: PhysicalMapType => ObjectType(classOf[scala.collection.Map[_, _]])
- case PhysicalStringType => ObjectType(classOf[java.lang.String])
- case _: PhysicalStructType => ObjectType(classOf[Row])
- // For other data types, return the data type as it is.
- case _ => dt
- }
- }
-
- private def deserializerFor(input: Expression, schema: StructType): Expression = {
- val fields = schema.zipWithIndex.map { case (f, i) =>
- deserializerFor(GetStructField(input, i))
- }
- CreateExternalRow(fields, schema)
+ def encoderFor(schema: StructType): AgnosticEncoder[Row] = {
+ encoderFor(schema, lenient = false)
}
- private def deserializerFor(input: Expression): Expression = {
- deserializerFor(input, input.dataType)
+ def encoderFor(schema: StructType, lenient: Boolean): AgnosticEncoder[Row] = {
+ encoderForDataType(schema, lenient).asInstanceOf[AgnosticEncoder[Row]]
}
- @tailrec
- private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match {
- case dt if ScalaReflection.isNativeType(dt) => input
-
- case p: PythonUserDefinedType => deserializerFor(input, p.sqlType)
-
+ private[catalyst] def encoderForDataType(
+ dataType: DataType,
+ lenient: Boolean): AgnosticEncoder[_] = dataType match {
+ case NullType => NullEncoder
+ case BooleanType => BoxedBooleanEncoder
+ case ByteType => BoxedByteEncoder
+ case ShortType => BoxedShortEncoder
+ case IntegerType => BoxedIntEncoder
+ case LongType => BoxedLongEncoder
+ case FloatType => BoxedFloatEncoder
+ case DoubleType => BoxedDoubleEncoder
+ case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
+ case BinaryType => BinaryEncoder
+ case StringType => StringEncoder
+ case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient)
+ case TimestampType => TimestampEncoder(lenient)
+ case TimestampNTZType => LocalDateTimeEncoder
+ case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateEncoder(lenient)
+ case DateType => DateEncoder(lenient)
+ case CalendarIntervalType => CalendarIntervalEncoder
+ case _: DayTimeIntervalType => DayTimeIntervalEncoder
+ case _: YearMonthIntervalType => YearMonthIntervalEncoder
+ case p: PythonUserDefinedType =>
+ // TODO check if this works.
+ encoderForDataType(p.sqlType, lenient)
case udt: UserDefinedType[_] =>
val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType])
val udtClass: Class[_] = if (annotation != null) {
@@ -295,84 +109,26 @@ object RowEncoder {
throw QueryExecutionErrors.userDefinedTypeNotAnnotatedAndRegisteredError(udt)
}
}
- val obj = NewInstance(
- udtClass,
- Nil,
- dataType = ObjectType(udtClass))
- Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
-
- case TimestampType =>
- if (SQLConf.get.datetimeJava8ApiEnabled) {
- createDeserializerForInstant(input)
- } else {
- createDeserializerForSqlTimestamp(input)
- }
-
- case TimestampNTZType =>
- createDeserializerForLocalDateTime(input)
-
- case DateType =>
- if (SQLConf.get.datetimeJava8ApiEnabled) {
- createDeserializerForLocalDate(input)
- } else {
- createDeserializerForSqlDate(input)
- }
-
- case _: DayTimeIntervalType => createDeserializerForDuration(input)
-
- case _: YearMonthIntervalType => createDeserializerForPeriod(input)
-
- case _: DecimalType => createDeserializerForJavaBigDecimal(input,
returnNullable = false)
-
- case StringType => createDeserializerForString(input, returnNullable =
false)
-
- case ArrayType(et, nullable) =>
- val arrayData =
- Invoke(
- MapObjects(deserializerFor(_), input, et),
- "array",
- ObjectType(classOf[Array[_]]), returnNullable = false)
- // TODO should use `scala.collection.immutable.ArrayDeq.unsafeMake` method to create
- // `immutable.Seq` in Scala 2.13 when Scala version compatibility is no longer required.
- StaticInvoke(
- scala.collection.mutable.WrappedArray.getClass,
- ObjectType(classOf[scala.collection.Seq[_]]),
- "make",
- arrayData :: Nil,
- returnNullable = false)
-
- case MapType(kt, vt, valueNullable) =>
- val keyArrayType = ArrayType(kt, false)
- val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
-
- val valueArrayType = ArrayType(vt, valueNullable)
- val valueData = deserializerFor(Invoke(input, "valueArray",
valueArrayType))
-
- StaticInvoke(
- ArrayBasedMapData.getClass,
- ObjectType(classOf[Map[_, _]]),
- "toScalaMap",
- keyData :: valueData :: Nil,
- returnNullable = false)
-
- case schema @ StructType(fields) =>
- val convertedFields = fields.zipWithIndex.map { case (f, i) =>
- If(
- Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
- Literal.create(null, externalDataTypeFor(f.dataType)),
- deserializerFor(GetStructField(input, i)))
- }
- If(IsNull(input),
- Literal.create(null, externalDataTypeFor(input.dataType)),
- CreateExternalRow(convertedFields, schema))
-
- // For other data types, return the internal catalyst value as it is.
- case _ => input
- }
-
- private def expressionForNullableExpr(
- expr: Expression,
- newExprWhenNotNull: Expression): Expression = {
- If(IsNull(expr), Literal.create(null, newExprWhenNotNull.dataType), newExprWhenNotNull)
+ UDTEncoder(udt, udtClass.asInstanceOf[Class[_ <: UserDefinedType[_]]])
+ case ArrayType(elementType, containsNull) =>
+ IterableEncoder(
+ classTag[mutable.WrappedArray[_]],
+ encoderForDataType(elementType, lenient),
+ containsNull,
+ lenientSerialization = true)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapEncoder(
+ classTag[scala.collection.Map[_, _]],
+ encoderForDataType(keyType, lenient),
+ encoderForDataType(valueType, lenient),
+ valueContainsNull)
+ case StructType(fields) =>
+ AgnosticRowEncoder(fields.map { field =>
+ EncoderField(
+ field.name,
+ encoderForDataType(field.dataType, lenient),
+ field.nullable,
+ field.metadata)
+ })
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index a644b90a96f..56facda2af6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.catalyst.expressions.objects
import java.lang.reflect.{Method, Modifier}
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.collection.mutable.{Builder, WrappedArray}
import scala.reflect.ClassTag
-import scala.util.{Properties, Try}
+import scala.util.Try
import org.apache.commons.lang3.reflect.MethodUtils
@@ -30,7 +31,6 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -859,7 +859,7 @@ case class MapObjects private(
case _ => inputData.dataType
}
- private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
+ private def executeFuncOnCollection(inputCollection: Iterable[_]): Iterator[_] = {
val row = new GenericInternalRow(1)
inputCollection.iterator.map { element =>
row.update(0, element)
@@ -867,7 +867,7 @@ case class MapObjects private(
}
}
- private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
+ private lazy val convertToSeq: Any => scala.collection.Seq[_] = inputDataType match {
case ObjectType(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
_.asInstanceOf[scala.collection.Seq[_]].toSeq
case ObjectType(cls) if cls.isArray =>
@@ -879,17 +879,33 @@ case class MapObjects private(
if (inputCollection.getClass.isArray) {
inputCollection.asInstanceOf[Array[_]].toSeq
} else {
- inputCollection.asInstanceOf[Seq[_]]
+ inputCollection.asInstanceOf[scala.collection.Seq[_]]
}
}
case ArrayType(et, _) =>
_.asInstanceOf[ArrayData].toSeq[Any](et)
}
- private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
+ private def elementClassTag(): ClassTag[Any] = {
+ val clazz = lambdaFunction.dataType match {
+ case ObjectType(cls) => cls
+ case dt if lambdaFunction.nullable => ScalaReflection.javaBoxedType(dt)
+ case dt => ScalaReflection.dataTypeJavaClass(dt)
+ }
+ ClassTag(clazz).asInstanceOf[ClassTag[Any]]
+ }
+
+ private lazy val mapElements: scala.collection.Seq[_] => Any = customCollectionCls match {
case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
- // Scala WrappedArray
- inputCollection => WrappedArray.make(executeFuncOnCollection(inputCollection).toArray)
+ // The implicit tag is a workaround to deal with a small change in the
+ // (scala) signature of ArrayBuilder.make between Scala 2.12 and 2.13.
+ implicit val tag: ClassTag[Any] = elementClassTag()
+ input => {
+ val builder = mutable.ArrayBuilder.make[Any]
+ builder.sizeHint(input.size)
+ executeFuncOnCollection(input).foreach(builder += _)
+ mutable.WrappedArray.make(builder.result())
+ }
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
executeFuncOnCollection(_).toSeq
@@ -1047,44 +1063,20 @@ case class MapObjects private(
val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
case Some(cls) if classOf[WrappedArray[_]].isAssignableFrom(cls) =>
- def doCodeGenForScala212 = {
- // WrappedArray in Scala 2.12
- val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
- val builder = ctx.freshName("collectionBuilder")
- (
- s"""
- ${classOf[Builder[_, _]].getName} $builder = $getBuilder;
- $builder.sizeHint($dataLength);
- """,
- (genValue: String) => s"$builder.$$plus$$eq($genValue);",
- s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
- s"MODULE$$.make(((${classOf[IndexedSeq[_]].getName})$builder" +
- s".result()).toArray(scala.reflect.ClassTag$$.MODULE$$.Object()));"
- )
- }
-
- def doCodeGenForScala213 = {
- // In Scala 2.13, WrappedArray is mutable.ArraySeq and newBuilder method need
- // a ClassTag type construction parameter
- val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder(" +
- s"scala.reflect.ClassTag$$.MODULE$$.Object())"
- val builder = ctx.freshName("collectionBuilder")
- (
- s"""
+ val tag = ctx.addReferenceObj("tag", elementClassTag())
+ val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
+ val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
+ val builder = ctx.freshName("collectionBuilder")
+ (
+ s"""
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
- (genValue: String) => s"$builder.$$plus$$eq($genValue);",
- s"(${cls.getName})$builder.result();"
- )
- }
+ (genValue: String) => s"$builder.$$plus$$eq($genValue);",
+ s"(${cls.getName}) ${classOf[WrappedArray[_]].getName}$$." +
+ s"MODULE$$.make($builder.result());"
+ )
- val scalaVersion = Properties.versionNumberString
- if (scalaVersion.startsWith("2.12")) {
- doCodeGenForScala212
- } else {
- doCodeGenForScala213
- }
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) ||
classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala sequence or set
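
The generated path now shares the interpreted path's ClassTag via
ctx.addReferenceObj, so both produce a builder whose result array carries
the element class. A tiny sketch (assumed Integer tag, not from the patch)
of why the referenced tag alone is enough:

    import scala.reflect.ClassTag

    val tag = ClassTag(classOf[java.lang.Integer]).asInstanceOf[ClassTag[Any]]
    val arr = tag.newArray(2)  // statically Array[Any]
    assert(arr.getClass == classOf[Array[java.lang.Integer]])
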
@@ -1908,14 +1900,14 @@ case class GetExternalRowField(
* Validates the actual data type of input expression at runtime. If it doesn't match the
* expectation, throw an exception.
*/
-case class ValidateExternalType(child: Expression, expected: DataType, lenient: Boolean)
+case class ValidateExternalType(child: Expression, expected: DataType, externalDataType: DataType)
extends UnaryExpression with NonSQLExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object]))
override def nullable: Boolean = child.nullable
- override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected, lenient)
+ override val dataType: DataType = externalDataType
private lazy val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
@@ -1927,7 +1919,9 @@ case class ValidateExternalType(child: Expression, expected: DataType, lenient:
}
case _: ArrayType =>
(value: Any) => {
- value.getClass.isArray || value.isInstanceOf[Seq[_]]
+ value.getClass.isArray ||
+ value.isInstanceOf[scala.collection.Seq[_]] ||
+ value.isInstanceOf[Set[_]]
}
case _: DateType =>
(value: Any) => {
@@ -1968,7 +1962,8 @@ case class ValidateExternalType(child: Expression, expected: DataType, lenient:
classOf[scala.math.BigDecimal],
classOf[Decimal]))
case _: ArrayType =>
- s"$obj.getClass().isArray() || $obj instanceof
${classOf[scala.collection.Seq[_]].getName}"
+ val check = genCheckTypes(Seq(classOf[scala.collection.Seq[_]], classOf[Set[_]]))
+ s"$obj.getClass().isArray() || $check"
case _: DateType =>
genCheckTypes(Seq(classOf[java.sql.Date], classOf[java.time.LocalDate]))
case _: TimestampType =>
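
The two hunks above widen the same runtime check in the interpreted and
generated paths. A standalone sketch of the shapes an ArrayType input now
accepts (the helper name is illustrative only):

    def isValidArrayInput(value: Any): Boolean =
      value.getClass.isArray ||
        value.isInstanceOf[scala.collection.Seq[_]] ||
        value.isInstanceOf[Set[_]]

    assert(isValidArrayInput(Array(1, 2)))
    assert(isValidArrayInput(scala.collection.mutable.ArrayBuffer(1)))
    assert(isValidArrayInput(Set(1, 2)))
    assert(!isValidArrayInput("not a collection"))
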
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 7e7ce29972b..f8ebdfe7676 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.FooEnum.FooEnum
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
-import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, MapObjects, NewInstance}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -388,11 +388,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
}
test("SPARK-15062: Get correct serializer for List[_]") {
- val list = List(1, 2, 3)
val serializer = serializerFor[List[Int]]
- assert(serializer.isInstanceOf[NewInstance])
- assert(serializer.asInstanceOf[NewInstance]
- .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
+ assert(serializer.isInstanceOf[MapObjects])
+ val mapObjects = serializer.asInstanceOf[MapObjects]
+ assert(mapObjects.customCollectionCls.isEmpty)
}
test("SPARK 16792: Get correct deserializer for List[_]") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 3a0db1ca121..c6546105231 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -480,6 +480,8 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
encodeDecodeTest(ScroogeLikeExample(1),
"SPARK-40385 class with only a companion object constructor")
+ encodeDecodeTest(Array(Set(1, 2), Set(2, 3)), "array of sets")
+
productTest(("UDT", new ExamplePoint(0.1, 0.2)))
test("AnyVal class with Any fields") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index c6bddfa5eee..b133b38a559 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.encoders
+import scala.collection.mutable
import scala.util.Random
import org.apache.spark.sql.{RandomDataGenerator, Row}
@@ -310,6 +311,19 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
assert(e4.getMessage.contains("java.lang.String is not a valid external type"))
}
+ private def roundTripArray[T](dt: DataType, nullable: Boolean, data: Array[T]): Unit = {
+ val schema = new StructType().add("a", ArrayType(dt, nullable))
+ test(s"RowEncoder should return WrappedArray with properly typed array for
$schema") {
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val result = fromRow(encoder, toRow(encoder, Row(data))).getAs[mutable.WrappedArray[_]](0)
+ assert(result.array.getClass === data.getClass)
+ assert(result === data)
+ }
+ }
+
+ roundTripArray(IntegerType, nullable = false, Array(1, 2, 3).map(Int.box))
+ roundTripArray(StringType, nullable = true, Array("hello", "world", "!", null))
+
test("SPARK-25791: Datatype of serializers should be accessible") {
val udtSQLType = new StructType().add("a", IntegerType)
val pythonUDT = new PythonUserDefinedType(udtSQLType, "pyUDT", "serializedPyClass")
@@ -458,4 +472,14 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
}
}
}
+
+ test("Encoding an ArraySeq/WrappedArray in scala-2.13") {
+ val schema = new StructType()
+ .add("headers", ArrayType(new StructType()
+ .add("key", StringType)
+ .add("value", BinaryType)))
+ val encoder = RowEncoder(schema, lenient = true).resolveAndBind()
+ val data = Row(mutable.WrappedArray.make(Array(Row("key", "value".getBytes))))
+ val row = encoder.createSerializer()(data)
+ }
}
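
A hedged, self-contained version of the round trip the roundTripArray
helper above performs (the suite's toRow/fromRow helpers are replaced with
the encoder's own serializer and deserializer; names are assumed):

    import scala.collection.mutable
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType().add("a", ArrayType(IntegerType, containsNull = false))
    val encoder = RowEncoder(schema).resolveAndBind()
    val data = Array(1, 2, 3).map(Int.box)
    val result = encoder.createDeserializer()(encoder.createSerializer()(Row(data)))
      .getAs[mutable.WrappedArray[_]](0)
    assert(result.array.getClass == data.getClass)  // backing array stays Integer[]
    assert(result == data.toSeq)
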
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 737fcb1bada..265b0eeb8bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -332,7 +332,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"),
IntegerType,
- lenient = false) :: Nil)
+ IntegerType) :: Nil)
}
test("SPARK-17160: field names are properly escaped by AssertTrue") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 2286b734477..05ab7a65a32 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -496,10 +496,11 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
(java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
(Array(3, 2, 1), ArrayType(IntegerType))
).foreach { case (input, dt) =>
+ val enc = RowEncoder.encoderForDataType(dt, lenient = false)
val validateType = ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
dt,
- lenient = false)
+ ScalaReflection.lenientExternalDataTypeFor(enc))
checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
}
@@ -507,7 +508,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
ValidateExternalType(
GetExternalRowField(inputObject, index = 0, fieldName = "c0"),
DoubleType,
- lenient = false),
+ DoubleType),
InternalRow.fromSeq(Seq(Row(1))),
"java.lang.Integer is not a valid external type for schema of double")
}
@@ -559,10 +560,10 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
ExternalMapToCatalyst(
inputObject,
- ScalaReflection.dataTypeFor(keyEnc),
+ ScalaReflection.externalDataTypeFor(keyEnc),
kvSerializerFor(keyEnc),
keyNullable = keyEnc.nullable,
- ScalaReflection.dataTypeFor(valueEnc),
+ ScalaReflection.externalDataTypeFor(valueEnc),
kvSerializerFor(valueEnc),
valueNullable = valueEnc.nullable
)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]