This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 23e6274f54da [SPARK-48356][FOLLOW UP][SQL] Improve FOR statement's column schema inference 23e6274f54da is described below commit 23e6274f54daad42859d23c519a163ae3e9e5696 Author: David Milicevic <david.milice...@databricks.com> AuthorDate: Thu Jun 5 10:11:06 2025 -0700 [SPARK-48356][FOLLOW UP][SQL] Improve FOR statement's column schema inference ### What changes were proposed in this pull request? This pull request changes the `FOR` statement to infer column schemas from the query DataFrame, and to no longer implicitly infer the column schema in SetVariable. This is necessary due to type mismatch errors with complex nested types, e.g. `ARRAY<STRUCT<..>>`. ### Why are the changes needed? Bug fix for the FOR statement. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit test that specifically targets the problematic case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51053 from davidm-db/for_schema_inference. 
Lead-authored-by: David Milicevic <david.milice...@databricks.com> Co-authored-by: David Milicevic <163021185+davidm...@users.noreply.github.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/scripting/SqlScriptingExecutionNode.scala | 15 ++++++++---- .../sql/scripting/SqlScriptingExecutionSuite.scala | 1 - .../scripting/SqlScriptingInterpreterSuite.scala | 28 ++++++++++++++++++++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index bf9762b505fb..fa8aaf6d81c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.Exceptio import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.errors.SqlScriptingErrors -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, DataType} /** * Trait for all SQL scripting execution nodes used during interpretation phase. 
@@ -997,10 +997,14 @@ class ForStatementExec( private var state = ForState.VariableAssignment private var queryResult: util.Iterator[Row] = _ + private var queryColumnNameToDataType: Map[String, DataType] = _ private var isResultCacheValid = false private def cachedQueryResult(): util.Iterator[Row] = { if (!isResultCacheValid) { - queryResult = query.buildDataFrame(session).toLocalIterator() + val df = query.buildDataFrame(session) + queryResult = df.toLocalIterator() + queryColumnNameToDataType = df.schema.fields.map(f => f.name -> f.dataType).toMap + query.isExecuted = true isResultCacheValid = true } @@ -1063,7 +1067,7 @@ class ForStatementExec( val variableInitStatements = row.schema.names.toSeq .map { colName => (colName, createExpressionFromValue(row.getAs(colName))) } .flatMap { case (colName, expr) => Seq( - createDeclareVarExec(colName, expr), + createDeclareVarExec(colName), createSetVarExec(colName, expr) ) } @@ -1166,8 +1170,9 @@ class ForStatementExec( case _ => Literal(value) } - private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { - val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") + private def createDeclareVarExec(varName: String): SingleStatementExec = { + val defaultExpression = DefaultValueExpression( + Literal(null, queryColumnNameToDataType(varName)), "null") val declareVariable = CreateVariable( UnresolvedIdentifier(Seq(varName)), defaultExpression, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala index 3c0bb4020419..4e208caf4446 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala @@ -2720,7 +2720,6 @@ class SqlScriptingExecutionSuite extends QueryTest with SharedSparkSession { | 
SELECT varL3; | SELECT 1/0; | END; - | SELECT 5; | SELECT 1/0; | SELECT 6; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 7fb5a02aebe4..85e37d4b2309 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -3450,4 +3450,32 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } } + + test("for statement - structs in array have different values") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t( + | array_column ARRAY<STRUCT<id: INT, strCol: STRING, intArrayCol: ARRAY<INT>>> + | ); + | INSERT INTO t VALUES + | Array(Struct(1, null, Array(10)), + | Struct(2, "name", Array())); + | FOR SELECT * FROM t DO + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // declare array_column + Seq.empty[Row], // set array_column + Seq(Row(Seq(Row(1, null, Seq(10)), Row(2, "name", Seq.empty)))) + ) + verifySqlScriptResult(sqlScript, expected) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org