This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 87bc64b2357 [SPARK-44138][SQL] Prohibit non-deterministic expressions,
subqueries and aggregates in MERGE conditions
87bc64b2357 is described below
commit 87bc64b235716515932903477f160f440362e436
Author: aokolnychyi <[email protected]>
AuthorDate: Thu Jun 22 13:15:30 2023 -0700
[SPARK-44138][SQL] Prohibit non-deterministic expressions, subqueries and
aggregates in MERGE conditions
### What changes were proposed in this pull request?
This PR adds descriptive errors for cases when MERGE statements that are
converted into executable plans by Spark contain unsupported expressions such
as non-deterministic functions, subqueries, and aggregates.
### Why are the changes needed?
These changes are needed as the current version of the row-level framework
does not support these operations. Keep in mind the new validation only applies
to MERGE statements that are rewritten by `RewriteMergeIntoTable`. Data sources
that inject their own implementation are NOT affected.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
This PR comes with tests.
Closes #41694 from aokolnychyi/spark-44138.
Authored-by: aokolnychyi <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
core/src/main/resources/error/error-classes.json | 22 +++++
.../catalyst/analysis/RewriteMergeIntoTable.scala | 34 ++++++-
.../spark/sql/errors/QueryCompilationErrors.scala | 24 +++++
.../sql/connector/MergeIntoTableSuiteBase.scala | 105 ++++++++++++++++++++-
4 files changed, 183 insertions(+), 2 deletions(-)
diff --git a/core/src/main/resources/error/error-classes.json
b/core/src/main/resources/error/error-classes.json
index 264d9b7c3a0..78b54d5230d 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -2576,6 +2576,28 @@
}
}
},
+ "UNSUPPORTED_MERGE_CONDITION" : {
+ "message" : [
+ "MERGE operation contains unsupported <condName> condition."
+ ],
+ "subClass" : {
+ "AGGREGATE" : {
+ "message" : [
+ "Aggregates are not allowed: <cond>."
+ ]
+ },
+ "NON_DETERMINISTIC" : {
+ "message" : [
+ "Non-deterministic expressions are not allowed: <cond>."
+ ]
+ },
+ "SUBQUERY" : {
+ "message" : [
+ "Subqueries are not allowed: <cond>."
+ ]
+ }
+ }
+ },
"UNSUPPORTED_OVERWRITE" : {
"message" : [
"Can't overwrite the target that is also being read from."
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 7bfc476d29a..4ba33f4743e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute,
AttributeReference, Exists, Expression, IsNotNull, Literal, MetadataAttribute,
MonotonicallyIncreasingID, OuterReference, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute,
AttributeReference, Exists, Expression, IsNotNull, Literal, MetadataAttribute,
MonotonicallyIncreasingID, OuterReference, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral,
TrueLiteral}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType,
LeftAnti, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction,
Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction,
MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData,
UpdateAction, WriteDelta}
import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Discard,
Instruction, Keep, ROW_ID, Split}
@@ -27,6 +28,7 @@ import
org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.{RowLevelOperationTable,
SupportsDelta}
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -49,6 +51,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand
with PredicateHelper
EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
+ validateMergeIntoConditions(m)
+
// NOT MATCHED conditions may only refer to columns in source so
they can be pushed down
val insertAction = notMatchedActions.head.asInstanceOf[InsertAction]
val filteredSource = insertAction.condition match {
@@ -80,6 +84,8 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand
with PredicateHelper
EliminateSubqueryAliases(aliasedTable) match {
case r: DataSourceV2Relation =>
+ validateMergeIntoConditions(m)
+
// there are only NOT MATCHED actions, use a left anti join to
remove any matching rows
// and switch to using a regular append instead of a row-level MERGE
operation
// only unmatched source rows that match action conditions are
appended to the table
@@ -116,6 +122,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand
with PredicateHelper
EliminateSubqueryAliases(aliasedTable) match {
case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _,
_, _) =>
+ validateMergeIntoConditions(m)
val table = buildOperationTable(tbl, MERGE,
CaseInsensitiveStringMap.empty())
table.operation match {
case _: SupportsDelta =>
@@ -468,4 +475,29 @@ object RewriteMergeIntoTable extends
RewriteRowLevelCommand with PredicateHelper
throw new AnalysisException(s"Unexpected action: $other")
}
}
+
+  /**
+   * Validates the search condition and every conditional action of the given MERGE plan.
+   * The first condition rejected by `checkMergeIntoCondition` aborts the validation.
+   */
+  private def validateMergeIntoConditions(merge: MergeIntoTable): Unit = {
+    checkMergeIntoCondition("SEARCH", merge.mergeCondition)
+
+    // keep the clause order: MATCHED, NOT MATCHED, NOT MATCHED BY SOURCE
+    val allActions = Seq(
+      merge.matchedActions,
+      merge.notMatchedActions,
+      merge.notMatchedBySourceActions).flatten
+
+    // actions without a condition (or of any other kind) have nothing to validate
+    val namedConds = allActions.collect {
+      case DeleteAction(Some(deleteCond)) => "DELETE" -> deleteCond
+      case UpdateAction(Some(updateCond), _) => "UPDATE" -> updateCond
+      case InsertAction(Some(insertCond), _) => "INSERT" -> insertCond
+    }
+
+    namedConds.foreach { case (name, actionCond) =>
+      checkMergeIntoCondition(name, actionCond)
+    }
+  }
+
+  /**
+   * Rejects a MERGE condition that is non-deterministic, contains a subquery or contains
+   * an aggregate expression. The checks run in that order and the first violation is thrown.
+   *
+   * @param condName name of the condition reported in the error (e.g. "SEARCH" or "UPDATE")
+   * @param cond condition expression to check
+   */
+  private def checkMergeIntoCondition(condName: String, cond: Expression): Unit = {
+    // evaluated only if the earlier checks pass
+    def containsAggregate: Boolean = cond.exists {
+      case _: AggregateExpression => true
+      case _ => false
+    }
+
+    if (!cond.deterministic) {
+      throw QueryCompilationErrors.nonDeterministicMergeCondition(condName, cond)
+    } else if (SubqueryExpression.hasSubquery(cond)) {
+      throw QueryCompilationErrors.subqueryNotAllowedInMergeCondition(condName, cond)
+    } else if (containsAggregate) {
+      throw QueryCompilationErrors.aggregationNotAllowedInMergeCondition(condName, cond)
+    }
+  }
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 91ebc12b5cd..ecae9778fd0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3186,6 +3186,30 @@ private[sql] object QueryCompilationErrors extends
QueryErrorsBase {
messageParameters = Map.empty)
}
+  /**
+   * Builds the UNSUPPORTED_MERGE_CONDITION.NON_DETERMINISTIC error.
+   *
+   * @param condName name of the MERGE condition substituted into the message
+   * @param cond offending expression, rendered through `toSQLExpr`
+   */
+  def nonDeterministicMergeCondition(condName: String, cond: Expression): Throwable = {
+    val params = Map("condName" -> condName, "cond" -> toSQLExpr(cond))
+    new AnalysisException(
+      errorClass = "UNSUPPORTED_MERGE_CONDITION.NON_DETERMINISTIC",
+      messageParameters = params)
+  }
+
+  /**
+   * Builds the UNSUPPORTED_MERGE_CONDITION.SUBQUERY error.
+   *
+   * @param condName name of the MERGE condition substituted into the message
+   * @param cond offending expression, rendered through `toSQLExpr`
+   */
+  def subqueryNotAllowedInMergeCondition(condName: String, cond: Expression): Throwable = {
+    val params = Map("condName" -> condName, "cond" -> toSQLExpr(cond))
+    new AnalysisException(
+      errorClass = "UNSUPPORTED_MERGE_CONDITION.SUBQUERY",
+      messageParameters = params)
+  }
+
+  /**
+   * Builds the UNSUPPORTED_MERGE_CONDITION.AGGREGATE error.
+   *
+   * @param condName name of the MERGE condition substituted into the message
+   * @param cond offending expression, rendered through `toSQLExpr`
+   */
+  def aggregationNotAllowedInMergeCondition(condName: String, cond: Expression): Throwable = {
+    val params = Map("condName" -> condName, "cond" -> toSQLExpr(cond))
+    new AnalysisException(
+      errorClass = "UNSUPPORTED_MERGE_CONDITION.AGGREGATE",
+      messageParameters = params)
+  }
+
def failedToParseExistenceDefaultAsLiteral(fieldName: String, defaultValue:
String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1344",
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index 575cd29c993..bd641b2026b 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector
import org.apache.spark.SparkException
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue}
import org.apache.spark.sql.connector.expressions.LiteralValue
@@ -1319,6 +1319,109 @@ abstract class MergeIntoTableSuiteBase extends
RowLevelOperationSuiteBase {
}
}
+  // Checks that MERGE statements whose search or action conditions contain non-deterministic
+  // expressions, aggregates or subqueries fail with a descriptive AnalysisException.
+  test("unsupported merge into conditions") {
+    withTempView("source") {
+      createTable("pk INT NOT NULL, salary INT, dep STRING")
+
+      val sourceRows = Seq(
+        (1, 100, "hr"),
+        (2, 200, "finance"),
+        (3, 300, "hr"))
+      sourceRows.toDF("pk", "salary", "dep").createOrReplaceTempView("source")
+
+      // runs the statement and checks that it fails with an AnalysisException naming both
+      // the offending condition and the reason; separate asserts pinpoint which part is missing
+      def assertUnsupported(condName: String, errMsg: String)(mergeStmt: String): Unit = {
+        val e = intercept[AnalysisException] {
+          sql(mergeStmt)
+        }
+        assert(e.message.contains(s"unsupported $condName condition"))
+        assert(e.message.contains(errMsg))
+      }
+
+      // unsupported expressions that reference source columns
+      val unsupportedSourceExprs = Map(
+        "s.pk < rand()" -> "Non-deterministic expressions are not allowed",
+        "max(s.pk) < 10" -> "Aggregates are not allowed",
+        s"s.pk IN (SELECT pk FROM $tableNameAsString)" -> "Subqueries are not allowed")
+
+      // foreach (not map): the loop is executed only for its assertions
+      unsupportedSourceExprs.foreach { case (expr, errMsg) =>
+        assertUnsupported("SEARCH", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk AND $expr
+             |WHEN MATCHED THEN
+             | UPDATE SET *
+             |""".stripMargin
+        }
+
+        assertUnsupported("UPDATE", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED AND $expr THEN
+             | UPDATE SET *
+             |""".stripMargin
+        }
+
+        assertUnsupported("DELETE", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN MATCHED AND $expr THEN
+             | DELETE
+             |""".stripMargin
+        }
+
+        assertUnsupported("INSERT", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN NOT MATCHED AND $expr THEN
+             | INSERT *
+             |""".stripMargin
+        }
+      }
+
+      // unsupported expressions that reference target columns, which is what
+      // NOT MATCHED BY SOURCE conditions use below
+      val unsupportedTargetExprs = Map(
+        "t.pk < rand()" -> "Non-deterministic expressions are not allowed",
+        "max(t.pk) < 10" -> "Aggregates are not allowed",
+        s"t.pk IN (SELECT pk FROM $tableNameAsString)" -> "Subqueries are not allowed")
+
+      unsupportedTargetExprs.foreach { case (expr, errMsg) =>
+        assertUnsupported("SEARCH", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk AND $expr
+             |WHEN MATCHED THEN
+             | UPDATE SET *
+             |""".stripMargin
+        }
+
+        assertUnsupported("UPDATE", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN NOT MATCHED BY SOURCE AND $expr THEN
+             | UPDATE SET t.pk = -1
+             |""".stripMargin
+        }
+
+        assertUnsupported("DELETE", errMsg) {
+          s"""MERGE INTO $tableNameAsString t
+             |USING source s
+             |ON t.pk = s.pk
+             |WHEN NOT MATCHED BY SOURCE AND $expr THEN
+             | DELETE
+             |""".stripMargin
+        }
+      }
+    }
+  }
+
private def assertNoLeftBroadcastOrReplication(query: String): Unit = {
val plan = executeAndKeepPlan {
sql(query)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]