This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new cf5956f [SPARK-30899][SQL] CreateArray/CreateMap's data type should not depend on SQLConf.get
cf5956f is described below
commit cf5956f058607dac1866fddbee495f0d46c19c05
Author: iRakson <[email protected]>
AuthorDate: Fri Mar 6 16:45:06 2020 +0800
[SPARK-30899][SQL] CreateArray/CreateMap's data type should not depend on SQLConf.get
### What changes were proposed in this pull request?
Introduced a new parameter `useStringTypeWhenEmpty` for the `CreateMap` and `CreateArray` functions to remove the dependency on SQLConf.get.
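For illustration, a minimal standalone sketch of the pattern this patch applies (hypothetical `GlobalConf`/`MakeArray` stand-ins, not Spark code): the flag is read from the mutable global config exactly once, at construction time, and stored as an immutable constructor parameter.

```scala
// Hypothetical stand-in for SQLConf: a mutable, process-wide setting.
object GlobalConf {
  @volatile var useStringTypeWhenEmpty: Boolean = false
}

// The flag is an explicit constructor parameter, so the expression's type
// can no longer drift when the global config changes after construction.
case class MakeArray(children: Seq[String], useStringTypeWhenEmpty: Boolean) {
  def elementType: String =
    if (children.isEmpty && useStringTypeWhenEmpty) "string" else "null"
}

object MakeArray {
  // A one-argument factory keeps old call sites working: it snapshots the
  // config once, mirroring the new `def this(children)` in the diff below.
  def apply(children: Seq[String]): MakeArray =
    new MakeArray(children, GlobalConf.useStringTypeWhenEmpty)
}
```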
### Why are the changes needed?
This avoids the problem of the configuration changing between different phases of planning, which can silently break a query plan and lead to crashes or data corruption.
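To make the failure mode concrete (a hypothetical REPL sequence reusing the sketch above, not a test from this patch): if the type is re-derived from the global config on every call, the same expression can report two different types in two planning phases.

```scala
// Anti-pattern: re-reads the global config on every call (the old behavior).
case class FragileMakeArray(children: Seq[String]) {
  def elementType: String =
    if (children.isEmpty && GlobalConf.useStringTypeWhenEmpty) "string" else "null"
}

val e = FragileMakeArray(Nil)
val t1 = e.elementType                     // "null", seen during analysis
GlobalConf.useStringTypeWhenEmpty = true   // config flips mid-planning
val t2 = e.elementType                     // "string", seen during physical planning
// t1 != t2: the plan was built around one type and executes with another.
```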
### Does this PR introduce any user-facing change?
No
### How was this patch tested?
Existing UTs.
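No new tests were added; a regression check for the captured-at-construction behavior would look roughly like this (hypothetical, reusing the `MakeArray` sketch above):

```scala
// The snapshot taken at construction wins over later config changes.
GlobalConf.useStringTypeWhenEmpty = true
val arr = MakeArray(Nil)                  // captures true via the factory
GlobalConf.useStringTypeWhenEmpty = false // later change must be ignored
assert(arr.elementType == "string")       // passes: type fixed at construction
```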
Closes #27657 from iRakson/SPARK-30899.
Authored-by: iRakson <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit cba17e07e9f15673f274de1728f6137d600026e1)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 8 ++---
.../catalyst/expressions/complexTypeCreator.scala | 35 +++++++++++++++++++---
.../expressions/complexTypeExtractors.scala | 4 +--
.../sql/catalyst/optimizer/ComplexTypes.scala | 8 ++---
.../optimizer/NormalizeFloatingNumbers.scala | 8 ++---
5 files changed, 45 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index f416e8e..0a0bef6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -553,10 +553,10 @@ object TypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ CreateArray(children) if !haveSameType(children.map(_.dataType)) =>
+ case a @ CreateArray(children, _) if !haveSameType(children.map(_.dataType)) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
- case Some(finalDataType) => CreateArray(children.map(castIfNotSameType(_, finalDataType)))
+ case Some(finalDataType) => a.copy(children.map(castIfNotSameType(_, finalDataType)))
case None => a
}
@@ -592,7 +592,7 @@ object TypeCoercion {
case None => m
}
- case m @ CreateMap(children) if m.keys.length == m.values.length &&
+ case m @ CreateMap(children, _) if m.keys.length == m.values.length &&
(!haveSameType(m.keys.map(_.dataType)) || !haveSameType(m.values.map(_.dataType))) =>
val keyTypes = m.keys.map(_.dataType)
val newKeys = findWiderCommonType(keyTypes) match {
@@ -606,7 +606,7 @@ object TypeCoercion {
case None => m.values
}
- CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
+ m.copy(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 4bd85d3..6c31511 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -37,16 +37,23 @@ import org.apache.spark.unsafe.types.UTF8String
> SELECT _FUNC_(1, 2, 3);
[1,2,3]
""")
-case class CreateArray(children: Seq[Expression]) extends Expression {
+case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
+ extends Expression {
+
+ def this(children: Seq[Expression]) = {
+ this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE))
+ }
override def foldable: Boolean = children.forall(_.foldable)
+ override def stringArgs: Iterator[Any] = super.stringArgs.take(1)
+
override def checkInputDataTypes(): TypeCheckResult = {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}
private val defaultElementType: DataType = {
- if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) {
+ if (useStringTypeWhenEmpty) {
StringType
} else {
NullType
@@ -79,6 +86,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def prettyName: String = "array"
}
+object CreateArray {
+ def apply(children: Seq[Expression]): CreateArray = {
+ new CreateArray(children)
+ }
+}
+
private [sql] object GenArrayData {
/**
* Return Java code pieces based on DataType and array size to allocate ArrayData class
@@ -141,12 +154,18 @@ private [sql] object GenArrayData {
> SELECT _FUNC_(1.0, '2', 3.0, '4');
{1.0:"2",3.0:"4"}
""")
-case class CreateMap(children: Seq[Expression]) extends Expression {
+case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
+ extends Expression {
+
+ def this(children: Seq[Expression]) = {
+ this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE))
+ }
+
lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
lazy val values = children.indices.filter(_ % 2 != 0).map(children)
private val defaultElementType: DataType = {
- if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) {
+ if (useStringTypeWhenEmpty) {
StringType
} else {
NullType
@@ -155,6 +174,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
+ override def stringArgs: Iterator[Any] = super.stringArgs.take(1)
+
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(
@@ -215,6 +236,12 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def prettyName: String = "map"
}
+object CreateMap {
+ def apply(children: Seq[Expression]): CreateMap = {
+ new CreateMap(children)
+ }
+}
+
/**
* Returns a catalyst Map containing the two arrays in children expressions as keys and values.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index e9d60ed..9c600c9d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -275,9 +275,9 @@ trait GetArrayItemUtil {
if (ordinal.foldable && !ordinal.nullable) {
val intOrdinal = ordinal.eval().asInstanceOf[Number].intValue()
child match {
- case CreateArray(ar) if intOrdinal < ar.length =>
+ case CreateArray(ar, _) if intOrdinal < ar.length =>
ar(intOrdinal).nullable
- case GetArrayStructFields(CreateArray(elements), field, _, _, _)
+ case GetArrayStructFields(CreateArray(elements, _), field, _, _, _)
if intOrdinal < elements.length =>
elements(intOrdinal).nullable || field.nullable
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 28dc8e9..f79dabf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -41,14 +41,14 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
createNamedStruct.valExprs(ordinal)
// Remove redundant array indexing.
- case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
+ case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
// Instead of selecting the field on the entire array, select it from each member
// of the array. Pushing down the operation this way may open other optimizations
// opportunities (i.e. struct(...,x,...).x)
- CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))))
+ CreateArray(elems.map(GetStructField(_, ordinal, Some(field.name))), useStringTypeWhenEmpty)
// Remove redundant map lookup.
- case ga @ GetArrayItem(CreateArray(elems), IntegerLiteral(idx)) =>
+ case ga @ GetArrayItem(CreateArray(elems, _), IntegerLiteral(idx)) =>
// Instead of creating the array and then selecting one row, remove array creation
// altogether.
if (idx >= 0 && idx < elems.size) {
@@ -58,7 +58,7 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
// out of bounds, mimic the runtime behavior and return null
Literal(null, ga.dataType)
}
- case GetMapValue(CreateMap(elems), key) => CaseKeyWhen(key, elems)
+ case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index ea01d9e..5f94af5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -114,11 +114,11 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case CreateNamedStruct(children) =>
CreateNamedStruct(children.map(normalize))
- case CreateArray(children) =>
- CreateArray(children.map(normalize))
+ case CreateArray(children, useStringTypeWhenEmpty) =>
+ CreateArray(children.map(normalize), useStringTypeWhenEmpty)
- case CreateMap(children) =>
- CreateMap(children.map(normalize))
+ case CreateMap(children, useStringTypeWhenEmpty) =>
+ CreateMap(children.map(normalize), useStringTypeWhenEmpty)
case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]