This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new eba31a8de3f [SPARK-41806][SQL] Use AppendData.byName for SQL INSERT INTO by name for DSV2
eba31a8de3f is described below
commit eba31a8de3fb79f96255a0feb58db19842c9d16d
Author: Allison Portis <[email protected]>
AuthorDate: Fri Jan 6 10:42:16 2023 +0800
[SPARK-41806][SQL] Use AppendData.byName for SQL INSERT INTO by name for DSV2
### What changes were proposed in this pull request?
Use the DSv2 AppendData.byName plan for SQL INSERT INTO by name, instead of reordering the value list and converting the statement to an insert by ordinal (AppendData.byPosition).
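A minimal SQL sketch of the by-name behavior (mirroring the added DSv2 tests; the table name `tbl` is illustrative):
```sql
-- With a by-name insert, the column list is matched against the table schema by name,
-- so it may be given in any order and values are cast to the target column types.
CREATE TABLE tbl (id BIGINT, data STRING) USING parquet;
INSERT INTO tbl (id, data) VALUES (1, 'a');
INSERT INTO tbl (data, id) VALUES ('b', 2);  -- different order, resolved by name
```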
### Why are the changes needed?
Currently, for INSERT INTO by name, we reorder the value list and convert the statement to INSERT INTO by ordinal. Since the DSv2 logical nodes have the `isByName` flag, we don't need to do this. The current approach is limiting in that:
- Users must provide the full list of table columns (this limits functionality for features like generated columns; see [SPARK-41290](https://issues.apache.org/jira/browse/SPARK-41290))
- It allows ambiguous queries such as `INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')`, where the user provides both the static partition column 'c' and the column 'c' in the column list. We should check that the static partition column does not appear in the column list; see the added test for a more detailed example.
### Does this PR introduce _any_ user-facing change?
For versions 3.3 and below:
```sql
CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY (c);
INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')
SELECT * FROM t
```
```
+---+---+
| i| c|
+---+---+
| 2| 1|
+---+---+
```
For versions 3.4 and above:
```sql
CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY (c);
INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')
```
```
AnalysisException: [STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST] Static partition column c is also specified in the column list.
```
### How was this patch tested?
Unit tests are added.
Closes #39334 from allisonport-db/insert-into-by-name.
Authored-by: Allison Portis <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
core/src/main/resources/error/error-classes.json | 5 ++
.../spark/sql/catalyst/analysis/Analyzer.scala | 99 +++++++++++++++++++---
.../spark/sql/errors/QueryCompilationErrors.scala | 6 ++
.../org/apache/spark/sql/SQLInsertTestSuite.scala | 16 +++-
.../spark/sql/connector/DataSourceV2SQLSuite.scala | 96 +++++++++++++++++++++
.../execution/command/PlanResolutionSuite.scala | 30 ++++++-
6 files changed, 239 insertions(+), 13 deletions(-)
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 29cafdcc1b6..1d1952dce1b 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1145,6 +1145,11 @@
"Star (*) is not allowed in a select list when GROUP BY an ordinal
position is used."
]
},
+ "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST" : {
+ "message" : [
+ "Static partition column <staticName> is also specified in the column
list."
+ ]
+ },
"STREAM_FAILED" : {
"message" : [
"Query [id = <id>, runId = <runId>] terminated with exception: <message>"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1ebbfb9a39a..8fff0d41add 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1291,28 +1291,92 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
+ /** Handle INSERT INTO for DSv2 */
object ResolveInsertInto extends Rule[LogicalPlan] {
+
+ /** Add a project to use the table column names for INSERT INTO BY NAME */
+ private def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = {
+ SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver)
+
+ if (i.userSpecifiedCols.size != i.query.output.size) {
+ throw QueryCompilationErrors.writeTableWithMismatchedColumnsError(
+ i.userSpecifiedCols.size, i.query.output.size, i.query)
+ }
+ val projectByName = i.userSpecifiedCols.zip(i.query.output)
+ .map { case (userSpecifiedCol, queryOutputCol) =>
+ val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver)
+ .getOrElse(
+ throw QueryCompilationErrors.unresolvedAttributeError(
+ "UNRESOLVED_COLUMN", userSpecifiedCol,
i.table.output.map(_.name), i.origin))
+ (queryOutputCol.dataType, resolvedCol.dataType) match {
+ case (input: StructType, expected: StructType) =>
+ // Rename inner fields of the input column to pass the by-name
INSERT analysis.
+ Alias(Cast(queryOutputCol, renameFieldsInStruct(input,
expected)), resolvedCol.name)()
+ case _ =>
+ Alias(queryOutputCol, resolvedCol.name)()
+ }
+ }
+ Project(projectByName, i.query)
+ }
+
+ private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = {
+ if (input.length == expected.length) {
+ val newFields = input.zip(expected).map { case (f1, f2) =>
+ (f1.dataType, f2.dataType) match {
+ case (s1: StructType, s2: StructType) =>
+ f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2))
+ case _ =>
+ f1.copy(name = f2.name)
+ }
+ }
+ StructType(newFields)
+ } else {
+ input
+ }
+ }
+
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
AlwaysProcess.fn, ruleId) {
case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _)
- if i.query.resolved && i.userSpecifiedCols.isEmpty =>
+ if i.query.resolved =>
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name)
}
+ // Create a project if this is an INSERT INTO BY NAME query.
+ val projectByName = if (i.userSpecifiedCols.nonEmpty) {
+ Some(createProjectForByNameQuery(i))
+ } else {
+ None
+ }
+ val isByName = projectByName.nonEmpty
+
val partCols = partitionColumnNames(r.table)
validatePartitionSpec(partCols, i.partitionSpec)
val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get).toMap
- val query = addStaticPartitionColumns(r, i.query, staticPartitions)
+ val query = addStaticPartitionColumns(r, projectByName.getOrElse(i.query), staticPartitions,
+ isByName)
if (!i.overwrite) {
- AppendData.byPosition(r, query)
+ if (isByName) {
+ AppendData.byName(r, query)
+ } else {
+ AppendData.byPosition(r, query)
+ }
} else if (conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC) {
- OverwritePartitionsDynamic.byPosition(r, query)
+ if (isByName) {
+ OverwritePartitionsDynamic.byName(r, query)
+ } else {
+ OverwritePartitionsDynamic.byPosition(r, query)
+ }
} else {
- OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
+ if (isByName) {
+ OverwriteByExpression.byName(r, query, staticDeleteExpression(r, staticPartitions))
+ } else {
+ OverwriteByExpression.byPosition(r, query, staticDeleteExpression(r, staticPartitions))
+ }
}
}
@@ -1343,7 +1407,8 @@ class Analyzer(override val catalogManager: CatalogManager)
private def addStaticPartitionColumns(
relation: DataSourceV2Relation,
query: LogicalPlan,
- staticPartitions: Map[String, String]): LogicalPlan = {
+ staticPartitions: Map[String, String],
+ isByName: Boolean): LogicalPlan = {
if (staticPartitions.isEmpty) {
query
@@ -1352,13 +1417,23 @@ class Analyzer(override val catalogManager: CatalogManager)
// add any static value as a literal column
val withStaticPartitionValues = {
// for each static name, find the column name it will replace and check for unknowns.
- val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
+ val outputNameToStaticName = staticPartitions.keySet.map { staticName =>
+ if (isByName) {
+ // If this is INSERT INTO BY NAME, the query output's names will be the user specified
+ // column names. We need to make sure the static partition column name doesn't appear
+ // there to catch the following ambiguous query:
+ // INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')
+ if (query.output.find(col => conf.resolver(col.name, staticName)).nonEmpty) {
+ throw QueryCompilationErrors.staticPartitionInUserSpecifiedColumnsError(staticName)
+ }
+ }
relation.output.find(col => conf.resolver(col.name, staticName)) match {
case Some(attr) =>
attr.name -> staticName
case _ =>
throw QueryCompilationErrors.missingStaticPartitionColumn(staticName)
- }).toMap
+ }
+ }.toMap
val queryColumns = query.output.iterator
@@ -3646,11 +3721,15 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
+ /**
+ * A special rule to reorder columns for DSv1 when users specify a column list in INSERT INTO.
+ * DSv2 is handled by [[ResolveInsertInto]] separately.
+ */
object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
AlwaysProcess.fn, ruleId) {
- case i: InsertIntoStatement if i.table.resolved && i.query.resolved &&
- i.userSpecifiedCols.nonEmpty =>
+ case i: InsertIntoStatement if !i.table.isInstanceOf[DataSourceV2Relation] &&
+ i.table.resolved && i.query.resolved && i.userSpecifiedCols.nonEmpty =>
val resolved = resolveUserSpecifiedColumns(i)
val projection = addColumnListOnQuery(i.table.output, resolved, i.query)
i.copy(userSpecifiedCols = Nil, query = projection)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 621f3e1ca90..f06444847ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -163,6 +163,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
messageParameters = Map("columnName" -> staticName))
}
+ def staticPartitionInUserSpecifiedColumnsError(staticName: String): Throwable = {
+ new AnalysisException(
+ errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST",
+ messageParameters = Map("staticName" -> staticName))
+ }
+
def nestedGeneratorError(trimmedNestedGenerator: Expression): Throwable = {
new AnalysisException(errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS",
messageParameters = Map("expression" -> toSQLExpr(trimmedNestedGenerator)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
index f620c0b4c86..051ac0f3141 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala
@@ -201,8 +201,8 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils {
}
}
- test("insert with column list - mismatched target table out size after
rewritten query") {
- val v2Msg = "expected 2 columns but found"
+ test("insert with column list - missing columns") {
+ val v2Msg = "Cannot write incompatible data to table 'testcat.t1'"
val cols = Seq("c1", "c2", "c3", "c4")
withTable("t1") {
@@ -369,4 +369,16 @@ class DSV2SQLInsertTestSuite extends SQLInsertTestSuite with SharedSparkSession
.set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName)
.set(SQLConf.DEFAULT_CATALOG.key, "testcat")
}
+
+ test("static partition column name should not be used in the column list") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY
(c)")
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')")
+ },
+ errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST",
+ parameters = Map("staticName" -> "c"))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 03b42a760ea..a4b7f762dba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1078,6 +1078,102 @@ class DataSourceV2SQLSuiteV1Filter
}
}
+ test("insertInto: append by name") {
+ import testImplicits._
+ val t1 = "tbl"
+ withTable(t1) {
+ sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+ val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+ sql(s"INSERT INTO $t1(id, data) VALUES(1L, 'a')")
+ // Can be in a different order
+ sql(s"INSERT INTO $t1(data, id) VALUES('b', 2L)")
+ // Can be casted automatically
+ sql(s"INSERT INTO $t1(data, id) VALUES('c', 3)")
+ verifyTable(t1, df)
+ // Missing columns
+ assert(intercept[AnalysisException] {
+ sql(s"INSERT INTO $t1(data) VALUES(4)")
+ }.getMessage.contains("Cannot find data for output column 'id'"))
+ // Duplicate columns
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"INSERT INTO $t1(data, data) VALUES(5)")
+ },
+ errorClass = "COLUMN_ALREADY_EXISTS",
+ parameters = Map("columnName" -> "`data`")
+ )
+ }
+ }
+
+ test("insertInto: overwrite by name") {
+ import testImplicits._
+ val t1 = "tbl"
+ withTable(t1) {
+ sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+ sql(s"INSERT OVERWRITE $t1(id, data) VALUES(1L, 'a')")
+ verifyTable(t1, Seq((1L, "a")).toDF("id", "data"))
+ // Can be in a different order
+ sql(s"INSERT OVERWRITE $t1(data, id) VALUES('b', 2L)")
+ verifyTable(t1, Seq((2L, "b")).toDF("id", "data"))
+ // Can be casted automatically
+ sql(s"INSERT OVERWRITE $t1(data, id) VALUES('c', 3)")
+ verifyTable(t1, Seq((3L, "c")).toDF("id", "data"))
+ // Missing columns
+ assert(intercept[AnalysisException] {
+ sql(s"INSERT OVERWRITE $t1(data) VALUES(4)")
+ }.getMessage.contains("Cannot find data for output column 'id'"))
+ // Duplicate columns
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
+ },
+ errorClass = "COLUMN_ALREADY_EXISTS",
+ parameters = Map("columnName" -> "`data`")
+ )
+ }
+ }
+
+ dynamicOverwriteTest("insertInto: dynamic overwrite by name") {
+ import testImplicits._
+ val t1 = "tbl"
+ withTable(t1) {
+ sql(s"CREATE TABLE $t1 (id bigint, data string, data2 string) " +
+ s"USING $v2Format PARTITIONED BY (id)")
+ sql(s"INSERT OVERWRITE $t1(id, data, data2) VALUES(1L, 'a', 'b')")
+ verifyTable(t1, Seq((1L, "a", "b")).toDF("id", "data", "data2"))
+ // Can be in a different order
+ sql(s"INSERT OVERWRITE $t1(data, data2, id) VALUES('b', 'd', 2L)")
+ verifyTable(t1, Seq((1L, "a", "b"), (2L, "b", "d")).toDF("id", "data",
"data2"))
+ // Can be casted automatically
+ sql(s"INSERT OVERWRITE $t1(data, data2, id) VALUES('c', 'e', 1)")
+ verifyTable(t1, Seq((1L, "c", "e"), (2L, "b", "d")).toDF("id", "data",
"data2"))
+ // Missing columns
+ assert(intercept[AnalysisException] {
+ sql(s"INSERT OVERWRITE $t1(data, id) VALUES('a', 4)")
+ }.getMessage.contains("Cannot find data for output column 'data2'"))
+ // Duplicate columns
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
+ },
+ errorClass = "COLUMN_ALREADY_EXISTS",
+ parameters = Map("columnName" -> "`data`")
+ )
+ }
+ }
+
+ test("insertInto: static partition column name should not be used in the
column list") {
+ withTable("t") {
+ sql(s"CREATE TABLE t(i STRING, c string) USING $v2Format PARTITIONED BY
(c)")
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("INSERT OVERWRITE t PARTITION (c='1') (c) VALUES ('2')")
+ },
+ errorClass = "STATIC_PARTITION_COLUMN_IN_INSERT_COLUMN_LIST",
+ parameters = Map("staticName" -> "c"))
+ }
+ }
+
test("ShowViews: using v1 catalog, db name with multipartIdentifier ('a.b')
is not allowed.") {
checkError(
exception = intercept[AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index d78317c81a2..00d8101df83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat,
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, EvalMode, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
-import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, OverwriteByExpression, OverwritePartitionsDynamic, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.connector.FakeV2Provider
@@ -42,6 +42,7 @@ import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode}
import org.apache.spark.sql.sources.SimpleScanSource
import org.apache.spark.sql.types.{BooleanType, CharType, DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType}
@@ -1237,6 +1238,33 @@ class PlanResolutionSuite extends AnalysisTest {
}
}
+ test("InsertIntoStatement byName") {
+ val tblName = "testcat.tab1"
+ val insertSql = s"INSERT INTO $tblName(i, s) VALUES (3, 'a')"
+ val insertParsed = parseAndResolve(insertSql)
+ val overwriteSql = s"INSERT OVERWRITE $tblName(i, s) VALUES (3, 'a')"
+ val overwriteParsed = parseAndResolve(overwriteSql)
+ insertParsed match {
+ case AppendData(_: DataSourceV2Relation, _, _, isByName, _, _) =>
+ assert(isByName)
+ case _ => fail("Expected AppendData, but got:\n" +
insertParsed.treeString)
+ }
+ overwriteParsed match {
+ case OverwriteByExpression(_: DataSourceV2Relation, _, _, _, isByName,
_, _) =>
+ assert(isByName)
+ case _ => fail("Expected OverwriteByExpression, but got:\n" +
overwriteParsed.treeString)
+ }
+ withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+ val dynamicOverwriteParsed = parseAndResolve(overwriteSql)
+ dynamicOverwriteParsed match {
+ case OverwritePartitionsDynamic(_: DataSourceV2Relation, _, _, isByName, _) =>
+ assert(isByName)
+ case _ =>
+ fail("Expected OverwriteByExpression, but got:\n" + dynamicOverwriteParsed.treeString)
+ }
+ }
+ }
+
test("alter table: alter column") {
Seq("v1Table" -> true, "v2Table" -> false, "testcat.tab" -> false).foreach
{
case (tblName, useV1Command) =>
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]