This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 58e5768a09b7 [SPARK-53482][SQL] MERGE INTO support for when source has
fewer nested fields than target
58e5768a09b7 is described below
commit 58e5768a09b7f18087bf02e45aede8fa5b2a352c
Author: Szehon Ho <[email protected]>
AuthorDate: Thu Oct 30 13:08:05 2025 +0800
[SPARK-53482][SQL] MERGE INTO support for when source has fewer nested fields
than target
### What changes were proposed in this pull request?
Support MERGE INTO where the source has fewer fields than the target. This is
already partially supported as part of:
https://github.com/apache/spark/pull/51698, but only for top-level fields.
This patch supports it even for nested fields (structs, including structs
nested within other structs, arrays, and maps).
It modifies the MERGE INTO assignment alignment to re-use existing logic in
TableOutputResolver, resolving missing values in structs to null or to their
default. UPDATE can also benefit from this, but that is left for a subsequent PR.
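For illustration, a minimal sketch of the newly supported shape (hypothetical
table names; assumes a catalog whose tables support MERGE INTO, e.g. a DSv2
catalog with row-level operation support):

```scala
// Target struct has fields c1 and c2; the source struct only has c1.
// With this change, the missing nested field s.c2 resolves to NULL
// instead of failing analysis.
spark.sql("CREATE TABLE demo.target (pk INT, s STRUCT<c1: INT, c2: STRING>)")
spark.sql("CREATE TABLE demo.source (pk INT, s STRUCT<c1: INT>)")
spark.sql("INSERT INTO demo.source VALUES (1, named_struct('c1', 10))")
spark.sql("""
  MERGE INTO demo.target t
  USING demo.source src
  ON t.pk = src.pk
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED THEN INSERT *
""")
// demo.target now contains (1, {"c1": 10, "c2": null})
```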
### Why are the changes needed?
For cases where the source has fewer fields than the target in MERGE INTO, the
command should behave more gracefully (inserting null values where a source
field does not exist).
### Does this PR introduce _any_ user-facing change?
No, except that this scenario used to fail and now succeeds.
The behavior is gated behind a new flag,
"spark.sql.merge.source.nested.type.coercion.enabled", enabled by default.
### How was this patch tested?
Added unit tests to MergeIntoTableSuiteBase.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52347 from szehon-ho/nested_merge_round_3.
Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/AssignmentUtils.scala | 31 +-
.../ResolveRowLevelCommandAssignments.scala | 22 +-
.../catalyst/analysis/TableOutputResolver.scala | 69 ++-
.../org/apache/spark/sql/internal/SQLConf.scala | 12 +
.../sql/connector/MergeIntoTableSuiteBase.scala | 603 ++++++++++++++++++++-
.../command/AlignAssignmentsSuiteBase.scala | 7 +
.../command/AlignMergeAssignmentsSuite.scala | 82 ++-
7 files changed, 753 insertions(+), 73 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
index 43631e1afc40..145c9077a4c2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable
import org.apache.spark.sql.catalyst.SQLConfHelper
+import
org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE,
RECURSE}
import org.apache.spark.sql.catalyst.expressions.{Attribute,
CreateNamedStruct, Expression, GetStructField, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -49,11 +50,14 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
*
* @param attrs table attributes
* @param assignments assignments to align
+ * @param coerceNestedTypes whether to coerce nested types in complex values
+ * to match the target type
* @return aligned update assignments that match table attributes
*/
def alignUpdateAssignments(
attrs: Seq[Attribute],
- assignments: Seq[Assignment]): Seq[Assignment] = {
+ assignments: Seq[Assignment],
+ coerceNestedTypes: Boolean): Seq[Assignment] = {
val errors = new mutable.ArrayBuffer[String]()
@@ -63,7 +67,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport
{
colExpr = attr,
assignments,
addError = err => errors += err,
- colPath = Seq(attr.name))
+ colPath = Seq(attr.name),
+ coerceNestedTypes)
}
if (errors.nonEmpty) {
@@ -84,11 +89,14 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
*
* @param attrs table attributes
* @param assignments insert assignments to align
+ * @param coerceNestedTypes whether to coerce nested types in complex values
+ * to match the target type
* @return aligned insert assignments that match table attributes
*/
def alignInsertAssignments(
attrs: Seq[Attribute],
- assignments: Seq[Assignment]): Seq[Assignment] = {
+ assignments: Seq[Assignment],
+ coerceNestedTypes: Boolean = false): Seq[Assignment] = {
val errors = new mutable.ArrayBuffer[String]()
@@ -120,8 +128,9 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
val colPath = Seq(attr.name)
val actualAttr = restoreActualType(attr)
val value = matchingAssignments.head.value
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
TableOutputResolver.resolveUpdate(
- "", value, actualAttr, conf, err => errors += err, colPath)
+ "", value, actualAttr, conf, err => errors += err, colPath,
coerceMode)
}
Assignment(attr, resolvedValue)
}
@@ -142,7 +151,8 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
colExpr: Expression,
assignments: Seq[Assignment],
addError: String => Unit,
- colPath: Seq[String]): Expression = {
+ colPath: Seq[String],
+ coerceNestedTypes: Boolean = false): Expression = {
val (exactAssignments, otherAssignments) = assignments.partition {
assignment =>
assignment.key.semanticEquals(colExpr)
@@ -165,9 +175,10 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
} else if (exactAssignments.nonEmpty) {
val value = exactAssignments.head.value
- TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath)
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+ TableOutputResolver.resolveUpdate("", value, col, conf, addError,
colPath, coerceMode)
} else {
- applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath)
+ applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath,
coerceNestedTypes)
}
}
@@ -176,7 +187,8 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
colExpr: Expression,
assignments: Seq[Assignment],
addError: String => Unit,
- colPath: Seq[String]): Expression = {
+ colPath: Seq[String],
+ coerceNestedTypes: Boolean): Expression = {
col.dataType match {
case structType: StructType =>
@@ -185,7 +197,8 @@ object AssignmentUtils extends SQLConfHelper with
CastSupport {
GetStructField(colExpr, ordinal, Some(field.name))
}
val updatedFieldExprs = fieldAttrs.zip(fieldExprs).map { case
(fieldAttr, fieldExpr) =>
- applyAssignments(fieldAttr, fieldExpr, assignments, addError,
colPath :+ fieldAttr.name)
+ applyAssignments(fieldAttr, fieldExpr, assignments, addError,
colPath :+ fieldAttr.name,
+ coerceNestedTypes)
}
toNamedStruct(structType, updatedFieldExprs)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index 83520b780f12..3eb528954b35 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
/**
@@ -42,7 +43,8 @@ object ResolveRowLevelCommandAssignments extends
Rule[LogicalPlan] {
case u: UpdateTable if !u.skipSchemaResolution && u.resolved &&
u.rewritable && !u.aligned =>
validateStoreAssignmentPolicy()
val newTable = cleanAttrMetadata(u.table)
- val newAssignments =
AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments)
+ val newAssignments =
AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments,
+ coerceNestedTypes = false)
u.copy(table = newTable, assignments = newAssignments)
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned
=>
@@ -51,11 +53,14 @@ object ResolveRowLevelCommandAssignments extends
Rule[LogicalPlan] {
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved &&
m.rewritable && !m.aligned &&
!m.needSchemaEvolution =>
validateStoreAssignmentPolicy()
+ val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes
m.copy(
targetTable = cleanAttrMetadata(m.targetTable),
- matchedActions = alignActions(m.targetTable.output, m.matchedActions),
- notMatchedActions = alignActions(m.targetTable.output,
m.notMatchedActions),
- notMatchedBySourceActions = alignActions(m.targetTable.output,
m.notMatchedBySourceActions))
+ matchedActions = alignActions(m.targetTable.output, m.matchedActions,
coerceNestedTypes),
+ notMatchedActions = alignActions(m.targetTable.output,
m.notMatchedActions,
+ coerceNestedTypes),
+ notMatchedBySourceActions = alignActions(m.targetTable.output,
m.notMatchedBySourceActions,
+ coerceNestedTypes))
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved &&
!m.aligned
&& !m.needSchemaEvolution =>
@@ -109,14 +114,17 @@ object ResolveRowLevelCommandAssignments extends
Rule[LogicalPlan] {
private def alignActions(
attrs: Seq[Attribute],
- actions: Seq[MergeAction]): Seq[MergeAction] = {
+ actions: Seq[MergeAction],
+ coerceNestedTypes: Boolean): Seq[MergeAction] = {
actions.map {
case u @ UpdateAction(_, assignments) =>
- u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs,
assignments))
+ u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs,
assignments,
+ coerceNestedTypes))
case d: DeleteAction =>
d
case i @ InsertAction(_, assignments) =>
- i.copy(assignments = AssignmentUtils.alignInsertAssignments(attrs,
assignments))
+ i.copy(assignments = AssignmentUtils.alignInsertAssignments(attrs,
assignments,
+ coerceNestedTypes))
case other =>
throw new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_3052",
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 59d015b8ee13..7eacc5ab9b2a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
+import
org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{FILL,
NONE, RECURSE}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
@@ -38,6 +39,17 @@ import org.apache.spark.sql.types.{ArrayType, DataType,
DecimalType, IntegralTyp
object TableOutputResolver extends SQLConfHelper with Logging {
+ /**
+ * Modes for filling in default or null values for missing columns.
+ * If FILL, fill missing top-level columns with their default values.
+ * If RECURSE, fill missing top-level columns and also recurse into nested
struct
+ * fields to fill null.
+ * If NONE, do not fill any missing columns.
+ */
+ object DefaultValueFillMode extends Enumeration {
+ val FILL, RECURSE, NONE = Value
+ }
+
def resolveVariableOutputColumns(
expected: Seq[VariableReference],
query: LogicalPlan,
@@ -90,15 +102,17 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
val errors = new mutable.ArrayBuffer[String]()
val resolved: Seq[NamedExpression] = if (byName) {
// If a top-level column does not have a corresponding value in the
input query, fill with
- // the column's default value. We need to pass `fillDefaultValue` as
true here, if the
+ // the column's default value. We need to pass `fillDefaultValue` as
FILL here, if the
// `supportColDefaultValue` parameter is also true.
+ val defaultValueFillMode = if (supportColDefaultValue) FILL else NONE
reorderColumnsByName(
tableName,
query.output,
expected,
conf,
errors += _,
- fillDefaultValue = supportColDefaultValue)
+ Nil,
+ defaultValueFillMode)
} else {
if (expected.size > query.output.size) {
throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
@@ -125,8 +139,10 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
col: Attribute,
conf: SQLConf,
addError: String => Unit,
- colPath: Seq[String]): Expression = {
+ colPath: Seq[String],
+ defaultValueFillMode: DefaultValueFillMode.Value): Expression = {
+ val fillChildDefaultValue = defaultValueFillMode == RECURSE
(value.dataType, col.dataType) match {
// no need to reorder inner fields or cast if types are already
compatible
case (valueType, colType) if
DataType.equalsIgnoreCompatibleNullability(valueType, colType) =>
@@ -141,17 +157,17 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
case (valueType: StructType, colType: StructType) =>
val resolvedValue = resolveStructType(
tableName, value, valueType, col, colType,
- byName = true, conf, addError, colPath)
+ byName = true, conf, addError, colPath, fillChildDefaultValue)
resolvedValue.getOrElse(value)
case (valueType: ArrayType, colType: ArrayType) =>
val resolvedValue = resolveArrayType(
tableName, value, valueType, col, colType,
- byName = true, conf, addError, colPath)
+ byName = true, conf, addError, colPath, fillChildDefaultValue)
resolvedValue.getOrElse(value)
case (valueType: MapType, colType: MapType) =>
val resolvedValue = resolveMapType(
tableName, value, valueType, col, colType,
- byName = true, conf, addError, colPath)
+ byName = true, conf, addError, colPath, fillChildDefaultValue)
resolvedValue.getOrElse(value)
case _ =>
checkUpdate(tableName, value, col, conf, addError, colPath)
@@ -288,13 +304,13 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
conf: SQLConf,
addError: String => Unit,
colPath: Seq[String] = Nil,
- fillDefaultValue: Boolean = false): Seq[NamedExpression] = {
+ defaultValueFillMode: DefaultValueFillMode.Value): Seq[NamedExpression]
= {
val matchedCols = mutable.HashSet.empty[String]
val reordered = expectedCols.flatMap { expectedCol =>
val matched = inputCols.filter(col => conf.resolver(col.name,
expectedCol.name))
val newColPath = colPath :+ expectedCol.name
if (matched.isEmpty) {
- val defaultExpr = if (fillDefaultValue) {
+ val defaultExpr = if (Set(FILL,
RECURSE).contains(defaultValueFillMode)) {
getDefaultValueExprOrNullLit(expectedCol,
conf.useNullsForMissingDefaultColumnValues)
} else {
None
@@ -315,19 +331,20 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
val actualExpectedCol = expectedCol.withDataType {
CharVarcharUtils.getRawType(expectedCol.metadata).getOrElse(expectedCol.dataType)
}
+ val childFillDefaultValue = defaultValueFillMode == RECURSE
(matchedCol.dataType, actualExpectedCol.dataType) match {
case (matchedType: StructType, expectedType: StructType) =>
resolveStructType(
tableName, matchedCol, matchedType, actualExpectedCol,
expectedType,
- byName = true, conf, addError, newColPath)
+ byName = true, conf, addError, newColPath, childFillDefaultValue)
case (matchedType: ArrayType, expectedType: ArrayType) =>
resolveArrayType(
tableName, matchedCol, matchedType, actualExpectedCol,
expectedType,
- byName = true, conf, addError, newColPath)
+ byName = true, conf, addError, newColPath, childFillDefaultValue)
case (matchedType: MapType, expectedType: MapType) =>
resolveMapType(
tableName, matchedCol, matchedType, actualExpectedCol,
expectedType,
- byName = true, conf, addError, newColPath)
+ byName = true, conf, addError, newColPath, childFillDefaultValue)
case _ =>
checkField(
tableName, actualExpectedCol, matchedCol, byName = true, conf,
addError, newColPath)
@@ -396,15 +413,15 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
case (inputType: StructType, expectedType: StructType) =>
resolveStructType(
tableName, inputCol, inputType, expectedCol, expectedType,
- byName = false, conf, addError, newColPath)
+ byName = false, conf, addError, newColPath, fillDefaultValue =
false)
case (inputType: ArrayType, expectedType: ArrayType) =>
resolveArrayType(
tableName, inputCol, inputType, expectedCol, expectedType,
- byName = false, conf, addError, newColPath)
+ byName = false, conf, addError, newColPath, fillDefaultValue =
false)
case (inputType: MapType, expectedType: MapType) =>
resolveMapType(
tableName, inputCol, inputType, expectedCol, expectedType,
- byName = false, conf, addError, newColPath)
+ byName = false, conf, addError, newColPath, fillDefaultValue =
false)
case _ =>
checkField(tableName, expectedCol, inputCol, byName = false, conf,
addError, newColPath)
}
@@ -439,13 +456,16 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
byName: Boolean,
conf: SQLConf,
addError: String => Unit,
- colPath: Seq[String]): Option[NamedExpression] = {
+ colPath: Seq[String],
+ fillDefaultValue: Boolean): Option[NamedExpression] = {
val nullCheckedInput = checkNullability(input, expected, conf, colPath)
val fields = inputType.zipWithIndex.map { case (f, i) =>
Alias(GetStructField(nullCheckedInput, i, Some(f.name)), f.name)()
}
+ val defaultValueMode = if (fillDefaultValue) RECURSE else NONE
val resolved = if (byName) {
- reorderColumnsByName(tableName, fields, toAttributes(expectedType),
conf, addError, colPath)
+ reorderColumnsByName(tableName, fields, toAttributes(expectedType),
conf, addError, colPath,
+ defaultValueMode)
} else {
resolveColumnsByPosition(
tableName, fields, toAttributes(expectedType), conf, addError, colPath)
@@ -472,13 +492,16 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
byName: Boolean,
conf: SQLConf,
addError: String => Unit,
- colPath: Seq[String]): Option[NamedExpression] = {
+ colPath: Seq[String],
+ fillDefaultValue: Boolean): Option[NamedExpression] = {
val nullCheckedInput = checkNullability(input, expected, conf, colPath)
val param = NamedLambdaVariable("element", inputType.elementType,
inputType.containsNull)
val fakeAttr =
AttributeReference("element", expectedType.elementType,
expectedType.containsNull)()
val res = if (byName) {
- reorderColumnsByName(tableName, Seq(param), Seq(fakeAttr), conf,
addError, colPath)
+ val defaultValueMode = if (fillDefaultValue) RECURSE else NONE
+ reorderColumnsByName(tableName, Seq(param), Seq(fakeAttr), conf,
addError, colPath,
+ defaultValueMode)
} else {
resolveColumnsByPosition(tableName, Seq(param), Seq(fakeAttr), conf,
addError, colPath)
}
@@ -506,13 +529,16 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
byName: Boolean,
conf: SQLConf,
addError: String => Unit,
- colPath: Seq[String]): Option[NamedExpression] = {
+ colPath: Seq[String],
+ fillDefaultValue: Boolean): Option[NamedExpression] = {
val nullCheckedInput = checkNullability(input, expected, conf, colPath)
val keyParam = NamedLambdaVariable("key", inputType.keyType, nullable =
false)
val fakeKeyAttr = AttributeReference("key", expectedType.keyType, nullable
= false)()
+ val defaultValueFillMode = if (fillDefaultValue) RECURSE else NONE
val resKey = if (byName) {
- reorderColumnsByName(tableName, Seq(keyParam), Seq(fakeKeyAttr), conf,
addError, colPath)
+ reorderColumnsByName(tableName, Seq(keyParam), Seq(fakeKeyAttr), conf,
addError, colPath,
+ defaultValueFillMode)
} else {
resolveColumnsByPosition(tableName, Seq(keyParam), Seq(fakeKeyAttr),
conf, addError, colPath)
}
@@ -522,7 +548,8 @@ object TableOutputResolver extends SQLConfHelper with
Logging {
val fakeValueAttr =
AttributeReference("value", expectedType.valueType,
expectedType.valueContainsNull)()
val resValue = if (byName) {
- reorderColumnsByName(tableName, Seq(valueParam), Seq(fakeValueAttr),
conf, addError, colPath)
+ reorderColumnsByName(tableName, Seq(valueParam), Seq(fakeValueAttr),
conf, addError, colPath,
+ defaultValueFillMode)
} else {
resolveColumnsByPosition(
tableName, Seq(valueParam), Seq(fakeValueAttr), conf, addError,
colPath)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 727259cb69e5..05c4b10879e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -6486,6 +6486,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED =
+ buildConf("spark.sql.merge.source.nested.type.coercion.enabled")
+ .internal()
+ .doc("If enabled, allow MERGE INTO to coerce source nested types if they
have less" +
+ "nested fields than the target table's nested types.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
/**
* Holds information about keys that have been deprecated.
*
@@ -7619,6 +7628,9 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def legacyXMLParserEnabled: Boolean =
getConf(SQLConf.LEGACY_XML_PARSER_ENABLED)
+ def coerceMergeNestedTypes: Boolean =
+ getConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 2e175951851a..98706c4afeae 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -2860,7 +2860,7 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
}
}
- test("merge into schema evolution add column with nested field and set
explicit columns") {
+ test("merge into schema evolution add column with nested struct and set
explicit columns") {
Seq(true, false).foreach { withSchemaEvolution =>
withTempView("source") {
createAndInitTable(
@@ -3041,9 +3041,7 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
}
}
- // currently the source struct needs to be fully compatible with target
struct
- // i.e. cannot remove a nested field
- test("merge into schema evolution replace column with nested field and set
all columns") {
+ test("merge into schema evolution replace column with nested struct and set
all columns") {
Seq(true, false).foreach { withSchemaEvolution =>
withTempView("source") {
createAndInitTable(
@@ -3057,7 +3055,7 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
StructField("s", StructType(Seq(
StructField("c1", IntegerType),
StructField("c2", StructType(Seq(
- // removed column 'a'
+ // missing column 'a'
StructField("m", MapType(StringType, StringType)),
StructField("c3", BooleanType) // new column
)))
@@ -3072,21 +3070,31 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
.createOrReplaceTempView("source")
val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
- val exception = intercept[org.apache.spark.sql.AnalysisException] {
- sql(
- s"""MERGE $schemaEvolutionClause
- |INTO $tableNameAsString t
- |USING source src
- |ON t.pk = src.pk
- |WHEN MATCHED THEN
- | UPDATE SET *
- |WHEN NOT MATCHED THEN
- | INSERT *
- |""".stripMargin)
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(null, Map("c" -> "d"), false)), "sales"),
+ Row(2, Row(20, Row(null, Map("e" -> "f"), true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `s`.`c2`"))
}
-
- assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
- assert(exception.getMessage.contains("Cannot find data for the output
column `s`.`c2`.`a`"))
}
sql(s"DROP TABLE IF EXISTS $tableNameAsString")
}
@@ -3213,6 +3221,253 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
}
}
+ test("merge into schema evolution replace column for struct in map and set
all columns") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ withTempView("source") {
+ val schema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c2",
IntegerType))),
+ StructType(Seq(StructField("c4", StringType), StructField("c5",
StringType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(schema))
+
+ val data = Seq(
+ Row(0, Map(Row(10, 10) -> Row("c", "c")), "hr"),
+ Row(1, Map(Row(20, 20) -> Row("d", "d")), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c3",
BooleanType))),
+ StructType(Seq(StructField("c4", StringType), StructField("c6",
BooleanType))))),
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Map(Row(10, true) -> Row("y", false)), "sales"),
+ Row(2, Map(Row(20, false) -> Row("z", true)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"),
+ Row(1, Map(Row(10, null, true) -> Row("y", null, false)),
"sales"),
+ Row(2, Map(Row(20, null, false) -> Row("z", null, true)),
"engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `m`.`key`"))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge into schema evolution replace column for struct in map and set
explicit columns") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ withTempView("source") {
+ val schema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c2",
IntegerType))),
+ StructType(Seq(StructField("c4", StringType), StructField("c5",
StringType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(schema))
+
+ val data = Seq(
+ Row(0, Map(Row(10, 10) -> Row("c", "c")), "hr"),
+ Row(1, Map(Row(20, 20) -> Row("d", "d")), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c3",
BooleanType))),
+ StructType(Seq(StructField("c4", StringType), StructField("c6",
BooleanType))))),
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Map(Row(10, true) -> Row("y", false)), "sales"),
+ Row(2, Map(Row(20, false) -> Row("z", true)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET t.m = src.m, t.dep = 'my_old_dep'
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, m, dep) VALUES (src.pk, src.m, 'my_new_dep')
+ |""".stripMargin
+
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Map(Row(10, 10, null) -> Row("c", "c", null)), "hr"),
+ Row(1, Map(Row(10, null, true) -> Row("y", null, false)),
"my_old_dep"),
+ Row(2, Map(Row(20, null, false) -> Row("z", null, true)),
"my_new_dep")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `m`.`key`"))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge into schema evolution replace column for struct in array and set
all columns") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ withTempView("source") {
+ val schema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("a", ArrayType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c2",
IntegerType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(schema))
+
+ val data = Seq(
+ Row(0, Array(Row(10, 10)), "hr"),
+ Row(1, Array(Row(20, 20)), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("a", ArrayType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c3",
BooleanType))))),
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Array(Row(10, true)), "sales"),
+ Row(2, Array(Row(20, false)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Array(Row(10, 10, null)), "hr"),
+ Row(1, Array(Row(10, null, true)), "sales"),
+ Row(2, Array(Row(20, null, false)), "engineering")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `a`.`element`"))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+
+ test("merge into schema evolution replace column for struct in array and set
explicit columns") {
+ Seq(true, false).foreach { withSchemaEvolution =>
+ withTempView("source") {
+ val schema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("a", ArrayType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c2",
IntegerType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(schema))
+
+ val data = Seq(
+ Row(0, Array(Row(10, 10)), "hr"),
+ Row(1, Array(Row(20, 20)), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+ .writeTo(tableNameAsString).append()
+
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("a", ArrayType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c3",
BooleanType))))),
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Array(Row(10, true)), "sales"),
+ Row(2, Array(Row(20, false)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ val schemaEvolutionClause = if (withSchemaEvolution) "WITH SCHEMA
EVOLUTION" else ""
+ val mergeStmt =
+ s"""MERGE $schemaEvolutionClause
+ |INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET t.a = src.a, t.dep = 'my_old_dep'
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, a, dep) VALUES (src.pk, src.a, 'my_new_dep')
+ |""".stripMargin
+
+ if (withSchemaEvolution) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(Row(0, Array(Row(10, 10, null)), "hr"),
+ Row(1, Array(Row(10, null, true)), "my_old_dep"),
+ Row(2, Array(Row(20, null, false)), "my_new_dep")))
+ } else {
+ val exception = intercept[org.apache.spark.sql.AnalysisException] {
+ sql(mergeStmt)
+ }
+ assert(exception.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_STRUCT_FIELDS")
+ assert(exception.getMessage.contains(
+ "Cannot write extra fields `c3` to the struct `a`.`element`"))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
test("merge into empty table with NOT MATCHED clause schema evolution") {
Seq(true, false) foreach { withSchemaEvolution =>
withTempView("source") {
@@ -3255,6 +3510,316 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase
}
}
+ test("merge into with source missing fields in top-level struct") {
+ withTempView("source") {
+ // Target table has struct with 3 fields at top level
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a", "c3": true }, "dep":
"sales"}""")
+
+ // Source table has struct with only 2 fields (c1, c2) - missing c3
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType)))), // missing c3 field
+ StructField("dep", StringType)))
+ val data = Seq(
+ Row(1, Row(10, "b"), "hr"),
+ Row(2, Row(20, "c"), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Row(1, "a", true), "sales"),
+ Row(1, Row(10, "b", null), "hr"),
+ Row(2, Row(20, "c", null), "engineering")))
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+
+ test("merge into with source missing fields in struct nested in array") {
+ withTempView("source") {
+ // Target table has struct with 3 fields (c1, c2, c3) in array
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |a ARRAY<STRUCT<c1: INT, c2: STRING, c3: BOOLEAN>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "a": [ { "c1": 1, "c2": "a", "c3": true } ], "dep":
"sales" }
+ |{ "pk": 1, "a": [ { "c1": 2, "c2": "b", "c3": false } ], "dep":
"sales" }"""
+ .stripMargin)
+
+ // Source table has struct with only 2 fields (c1, c2) - missing c3
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("a", ArrayType(
+ StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType))))), // missing c3 field
+ StructField("dep", StringType)))
+ val data = Seq(
+ Row(1, Array(Row(10, "c")), "hr"),
+ Row(2, Array(Row(30, "e")), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+
+ // Missing field c3 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Array(Row(1, "a", true)), "sales"),
+ Row(1, Array(Row(10, "c", null)), "hr"),
+ Row(2, Array(Row(30, "e", null)), "engineering")))
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+
+ test("merge into with source missing fields in struct nested in map key") {
+ withTempView("source") {
+ // Target table has struct with 2 fields in map key
+ val targetSchema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType), StructField("c2",
BooleanType))),
+ StructType(Seq(StructField("c3", StringType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
+
+ val targetData = Seq(
+ Row(0, Map(Row(10, true) -> Row("x")), "hr"),
+ Row(1, Map(Row(20, false) -> Row("y")), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Source table has struct with only 1 field (c1) in map key - missing c2
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType))), // missing c2
+ StructType(Seq(StructField("c3", StringType))))),
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Map(Row(10) -> Row("z")), "sales"),
+ Row(2, Map(Row(20) -> Row("w")), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+
+ // Missing field c2 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Map(Row(10, true) -> Row("x")), "hr"),
+ Row(1, Map(Row(10, null) -> Row("z")), "sales"),
+ Row(2, Map(Row(20, null) -> Row("w")), "engineering")))
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+
+ test("merge into with source missing fields in struct nested in map value") {
+ withTempView("source") {
+ // Target table has struct with 2 fields in map value
+ val targetSchema =
+ StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType))),
+ StructType(Seq(StructField("c1", StringType), StructField("c2",
BooleanType))))),
+ StructField("dep", StringType)))
+ createTable(CatalogV2Util.structTypeToV2Columns(targetSchema))
+
+ val targetData = Seq(
+ Row(0, Map(Row(10) -> Row("x", true)), "hr"),
+ Row(1, Map(Row(20) -> Row("y", false)), "sales"))
+ spark.createDataFrame(spark.sparkContext.parallelize(targetData),
targetSchema)
+ .writeTo(tableNameAsString).append()
+
+ // Source table has struct with only 1 field (c1) in map value - missing
c2
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("m", MapType(
+ StructType(Seq(StructField("c1", IntegerType))),
+ StructType(Seq(StructField("c1", StringType))))), // missing c2
+ StructField("dep", StringType)))
+ val sourceData = Seq(
+ Row(1, Map(Row(10) -> Row("z")), "sales"),
+ Row(2, Map(Row(20) -> Row("w")), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(sourceData),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ sql(
+ s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin)
+
+ // Missing field c2 should be filled with NULL
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(0, Map(Row(10) -> Row("x", true)), "hr"),
+ Row(1, Map(Row(10) -> Row("z", null)), "sales"),
+ Row(2, Map(Row(20) -> Row("w", null)), "engineering")))
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+
+ test("merge into with source missing fields in nested struct") {
+ Seq(true, false).foreach { nestedTypeCoercion =>
+ withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key
+ -> nestedTypeCoercion.toString) {
+ withTempView("source") {
+ // Target table has nested struct: s.c1, s.c2.a, s.c2.b
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRUCT<a: INT, b: BOOLEAN>>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 1, "s": { "c1": 2, "c2": { "a": 10, "b": true } } }
+ |{ "pk": 2, "s": { "c1": 2, "c2": { "a": 30, "b": false } }
}""".stripMargin)
+
+ // Source table is missing field 'b' in nested struct s.c2
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType, nullable = false),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StructType(Seq(
+ StructField("a", IntegerType)
+ // missing field 'b'
+ )))
+ ))),
+ StructField("dep", StringType)
+ ))
+ val data = Seq(
+ Row(1, Row(10, Row(20)), "sales"),
+ Row(2, Row(20, Row(30)), "engineering")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ // Missing field b should be filled with NULL
+ val mergeStmt = s"""MERGE INTO $tableNameAsString t
+ |USING source src
+ |ON t.pk = src.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET *
+ |WHEN NOT MATCHED THEN
+ | INSERT *
+ |""".stripMargin
+
+ if (nestedTypeCoercion) {
+ sql(mergeStmt)
+ checkAnswer(
+ sql(s"SELECT * FROM $tableNameAsString"),
+ Seq(
+ Row(1, Row(10, Row(20, null)), "sales"),
+ Row(2, Row(20, Row(30, null)), "engineering")))
+ } else {
+ val exception = intercept[Exception] {
+ sql(mergeStmt)
+ }
+ assert(exception.getMessage.contains(
+ """Cannot write incompatible data for the table
``""".stripMargin))
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+ }
+ }
+
+ test("merge with named_struct missing non-nullable field backup") {
+ withTempView("source") {
+ createAndInitTable(
+ s"""pk INT NOT NULL,
+ |s STRUCT<c1: INT, c2: STRING NOT NULL>,
+ |dep STRING""".stripMargin,
+ """{ "pk": 0, "s": { "c1": 1, "c2": "a" }, "dep": "sales" }
+ |{ "pk": 1, "s": { "c1": 2, "c2": "b" }, "dep": "hr" }"""
+ .stripMargin)
+
+ // Source table matches target table schema
+ val sourceTableSchema = StructType(Seq(
+ StructField("pk", IntegerType),
+ StructField("s", StructType(Seq(
+ StructField("c1", IntegerType),
+ StructField("c2", StringType, nullable = false)
+ ))),
+ StructField("dep", StringType)
+ ))
+
+ val data = Seq(
+ Row(1, Row(10, "a"), "engineering"),
+ Row(2, Row(20, "b"), "finance")
+ )
+ spark.createDataFrame(spark.sparkContext.parallelize(data),
sourceTableSchema)
+ .createOrReplaceTempView("source")
+
+ Seq(true, false).foreach { coerceNestedTypes =>
+ withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key
->
+ coerceNestedTypes.toString) {
+ // Test UPDATE with named_struct missing non-nullable field c2
+ val e = intercept[AnalysisException] {
+ sql(
+ s"""MERGE INTO $tableNameAsString t USING source
+ |ON t.pk = source.pk
+ |WHEN MATCHED THEN
+ | UPDATE SET s = named_struct('c1', source.s.c1), dep =
source.dep
+ |WHEN NOT MATCHED THEN
+ | INSERT (pk, s, dep) VALUES (source.pk, named_struct('c1',
1), source.dep)
+ |""".stripMargin)
+ }
+ assert(e.errorClass.get ==
"INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA")
+ assert(e.getMessage.contains("Cannot write incompatible data for the
table ``: " +
+ "Cannot find data for the output column `s`.`c2`."))
+ }
+ }
+ }
+ sql(s"DROP TABLE IF EXISTS $tableNameAsString")
+ }
+
private def findMergeExec(query: String): MergeRowsExec = {
val plan = executeAndKeepPlan {
sql(query)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
index 75837c59945f..14cf72c78dbe 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala
@@ -218,6 +218,13 @@ abstract class AlignAssignmentsSuiteBase extends
AnalysisTest {
}
}
+ protected def assertNoNullCheckExists(plan: LogicalPlan): Unit = {
+ val asserts = plan.expressions.flatMap(e => e.collect {
+ case assert: AssertNotNull => assert
+ })
+ assert(asserts.isEmpty, s"Must not have NOT NULL checks")
+ }
+
protected def assertNullCheckExists(plan: LogicalPlan, colPath:
Seq[String]): Unit = {
val asserts = plan.expressions.flatMap(e => e.collect {
case assert: AssertNotNull if assert.walkedTypePath == colPath => assert
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
index cd099a2a9481..8420e5e4d880 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala
@@ -690,20 +690,41 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
|""".stripMargin)
assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
- val e = intercept[AnalysisException] {
- parseAndResolve(
- s"""MERGE INTO nested_struct_table t USING nested_struct_table src
- |ON t.i = src.i
- |$clause THEN
- | UPDATE SET s.n_s = named_struct('dn_i', 1)
- |""".stripMargin
- )
+ Seq(true, false).foreach { coerceNestedTypes =>
+
withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val mergeStmt =
+ s"""MERGE INTO nested_struct_table t USING nested_struct_table
src
+ |ON t.i = src.i
+ |$clause THEN
+ | UPDATE SET s.n_s = named_struct('dn_i', 2L)
+ |""".stripMargin
+ if (coerceNestedTypes) {
+ val plan5 = parseAndResolve(mergeStmt)
+ // No null check for dn_i as it is explicitly set
+ assertNoNullCheckExists(plan5)
+ } else {
+ val e = intercept[AnalysisException] {
+ parseAndResolve(mergeStmt)
+ }
+ checkError(
+ exception = e,
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ parameters = Map("tableName" -> "``", "colName" ->
"`s`.`n_s`.`dn_l`")
+ )
+ }
+ }
}
- checkError(
- exception = e,
- condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "``", "colName" ->
"`s`.`n_s`.`dn_l`")
- )
+
+ // dn_i is a required field but not provided
+ assertAnalysisException(
+ s"""MERGE INTO nested_struct_table t USING nested_struct_table src
+ |ON t.i = src.i
+ |$clause THEN
+ | UPDATE SET s.n_s = named_struct('dn_l', 2L)
+ |""".stripMargin,
+ "Cannot write incompatible data for the table ``: " +
+ "Cannot find data for the output column `s`.`n_s`.`dn_i`")
// ANSI mode does NOT allow string to int casts
assertAnalysisException(
@@ -836,19 +857,46 @@ class AlignMergeAssignmentsSuite extends
AlignAssignmentsSuiteBase {
|""".stripMargin)
assertNullCheckExists(plan4, Seq("s", "n_s", "dn_i"))
+ Seq(true, false).foreach { coerceNestedTypes =>
+
withSQLConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED.key ->
+ coerceNestedTypes.toString) {
+ val mergeStmt =
+ s"""MERGE INTO nested_struct_table t USING nested_struct_table
src
+ |ON t.i = src.i
+ |$clause THEN
+ | UPDATE SET s.n_s = named_struct('dn_i', 1)
+ |""".stripMargin
+ if (coerceNestedTypes) {
+ val plan5 = parseAndResolve(mergeStmt)
+ // No null check for dn_i as it is explicitly set
+ assertNoNullCheckExists(plan5)
+ } else {
+ val e = intercept[AnalysisException] {
+ parseAndResolve(mergeStmt)
+ }
+ checkError(
+ exception = e,
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ parameters = Map("tableName" -> "``", "colName" ->
"`s`.`n_s`.`dn_l`"))
+ }
+ }
+ }
+
+ // dn_i is a required field but not provided
val e = intercept[AnalysisException] {
parseAndResolve(
s"""MERGE INTO nested_struct_table t USING nested_struct_table src
|ON t.i = src.i
|$clause THEN
- | UPDATE SET s.n_s = named_struct('dn_i', 1)
- |""".stripMargin
- )
+ | UPDATE SET s.n_s = named_struct('dn_l', 2L)
+ |""".stripMargin)
}
checkError(
exception = e,
condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "``", "colName" ->
"`s`.`n_s`.`dn_l`")
+ parameters = Map(
+ "tableName" -> "``",
+ "colName" -> "`s`.`n_s`.`dn_i`")
)
// strict mode does NOT allow string to int casts
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]