This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 23e6274f54da [SPARK-48356][FOLLOW UP][SQL] Improve FOR statement's 
column schema inference
23e6274f54da is described below

commit 23e6274f54daad42859d23c519a163ae3e9e5696
Author: David Milicevic <david.milice...@databricks.com>
AuthorDate: Thu Jun 5 10:11:06 2025 -0700

    [SPARK-48356][FOLLOW UP][SQL] Improve FOR statement's column schema 
inference
    
    ### What changes were proposed in this pull request?
    
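    This pull request changes the `FOR` statement to infer column schemas from
    the query DataFrame's schema, and to no longer infer them implicitly from the
    row values assigned in SetVariable. Previously, a loop variable's type was
    derived from the value in the current row. For complex nested types, e.g.
    `ARRAY<STRUCT<..>>`, that type may not match the column's declared type (for
    example, when a struct field is NULL or a nested array is empty), which
    caused type mismatch errors.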
    
    ### Why are the changes needed?
    
    Bug fix for the FOR statement.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
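    Added a new unit test that specifically targets the problematic case: an
    `ARRAY<STRUCT<..>>` column whose elements contain a NULL field and an empty
    nested array.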
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51053 from davidm-db/for_schema_inference.
    
    Lead-authored-by: David Milicevic <david.milice...@databricks.com>
    Co-authored-by: David Milicevic 
<163021185+davidm...@users.noreply.github.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/scripting/SqlScriptingExecutionNode.scala  | 15 ++++++++----
 .../sql/scripting/SqlScriptingExecutionSuite.scala |  1 -
 .../scripting/SqlScriptingInterpreterSuite.scala   | 28 ++++++++++++++++++++++
 3 files changed, 38 insertions(+), 6 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index bf9762b505fb..fa8aaf6d81c2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -30,7 +30,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.Exceptio
 import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
 import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.errors.SqlScriptingErrors
-import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.sql.types.{BooleanType, DataType}
 
 /**
  * Trait for all SQL scripting execution nodes used during interpretation 
phase.
@@ -997,10 +997,14 @@ class ForStatementExec(
   private var state = ForState.VariableAssignment
 
   private var queryResult: util.Iterator[Row] = _
+  private var queryColumnNameToDataType: Map[String, DataType] = _
   private var isResultCacheValid = false
   private def cachedQueryResult(): util.Iterator[Row] = {
     if (!isResultCacheValid) {
-      queryResult = query.buildDataFrame(session).toLocalIterator()
+      val df = query.buildDataFrame(session)
+      queryResult = df.toLocalIterator()
+      queryColumnNameToDataType = df.schema.fields.map(f => f.name -> 
f.dataType).toMap
+
       query.isExecuted = true
       isResultCacheValid = true
     }
@@ -1063,7 +1067,7 @@ class ForStatementExec(
           val variableInitStatements = row.schema.names.toSeq
             .map { colName => (colName, 
createExpressionFromValue(row.getAs(colName))) }
             .flatMap { case (colName, expr) => Seq(
-              createDeclareVarExec(colName, expr),
+              createDeclareVarExec(colName),
               createSetVarExec(colName, expr)
             ) }
 
@@ -1166,8 +1170,9 @@ class ForStatementExec(
     case _ => Literal(value)
   }
 
-  private def createDeclareVarExec(varName: String, variable: Expression): 
SingleStatementExec = {
-    val defaultExpression = DefaultValueExpression(Literal(null, 
variable.dataType), "null")
+  private def createDeclareVarExec(varName: String): SingleStatementExec = {
+    val defaultExpression = DefaultValueExpression(
+      Literal(null, queryColumnNameToDataType(varName)), "null")
     val declareVariable = CreateVariable(
       UnresolvedIdentifier(Seq(varName)),
       defaultExpression,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
index 3c0bb4020419..4e208caf4446 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionSuite.scala
@@ -2720,7 +2720,6 @@ class SqlScriptingExecutionSuite extends QueryTest with 
SharedSparkSession {
         |          SELECT varL3;
         |          SELECT 1/0;
         |        END;
-
         |        SELECT 5;
         |        SELECT 1/0;
         |        SELECT 6;
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 7fb5a02aebe4..85e37d4b2309 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -3450,4 +3450,32 @@ class SqlScriptingInterpreterSuite extends QueryTest 
with SharedSparkSession {
       verifySqlScriptResult(commands, expected)
     }
   }
+
+  test("for statement - structs in array have different values") {
+    withTable("t") {
+      val sqlScript =
+        """
+          |BEGIN
+          | CREATE TABLE t(
+          |   array_column ARRAY<STRUCT<id: INT, strCol: STRING, intArrayCol: 
ARRAY<INT>>>
+          | );
+          | INSERT INTO t VALUES
+          |  Array(Struct(1, null, Array(10)),
+          |        Struct(2, "name", Array()));
+          | FOR SELECT * FROM t DO
+          |   SELECT array_column;
+          | END FOR;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(
+        Seq.empty[Row], // create table
+        Seq.empty[Row], // insert
+        Seq.empty[Row], // declare array_column
+        Seq.empty[Row], // set array_column
+        Seq(Row(Seq(Row(1, null, Seq(10)), Row(2, "name", Seq.empty))))
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to