This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d3c0f2f30b8f [SPARK-52033][SQL] Fix Generate node bug where the 
output of the child node can have multiple copies of the same Attribute
d3c0f2f30b8f is described below

commit d3c0f2f30b8f056fde699db756c8b5798a7bcb1d
Author: Harsh Motwani <harsh.motw...@databricks.com>
AuthorDate: Fri May 9 10:46:38 2025 +0800

    [SPARK-52033][SQL] Fix Generate node bug where the output of the child 
node can have multiple copies of the same Attribute
    
    ### What changes were proposed in this pull request?
    
    Instead of checking `requiredAttrSet.contains(attr)`, this PR proposes that 
we also verify the count of these attributes.
    
    ### Why are the changes needed?
    
    Sometimes, one of the child's attributes could be more frequent in the 
input than the output but all the copies would pass through resulting in an 
assertion failure in Codegen. In interpreted mode, we see incorrect results. 
This is a repro which fails in codegen but returns `{null, null}` as opposed to 
`{1, 1}` in interpreted mode:
    
    ```
    sql("""create or replace temporary function spark_func (params 
array<struct<x int, y int>>)
                | returns STRUCT<a: int, b: int> LANGUAGE SQL
                | return (select ns from (
                | SELECT try_divide(SUM(item.x * item.y), SUM(item.x * item.x)) 
AS beta1,
                | NAMED_STRUCT('a', beta1,'b', beta1) ns
                | FROM (SELECT params) LATERAL VIEW EXPLODE(params) AS item 
LIMIT 1));""".stripMargin)
    
    sql("""select spark_func(collect_list(NAMED_STRUCT('x', 1, 'y', 1))) as 
result;""").collect()
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, described above.
    
    ### How was this patch tested?
    
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #50823 from harshmotw-db/harsh-motwani_data/generate_udf_fix.
    
    Lead-authored-by: Harsh Motwani <harsh.motw...@databricks.com>
    Co-authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/DataFrameTableValuedFunctionsSuite.scala | 18 ++++++++++++++++++
 .../org/apache/spark/sql/execution/GenerateExec.scala  | 14 ++++++++------
 .../spark/sql/DataFrameTableValuedFunctionsSuite.scala | 18 ++++++++++++++++++
 3 files changed, 44 insertions(+), 6 deletions(-)

diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index d11f276e8ed4..e619279940d2 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -523,4 +523,22 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with RemoteSparkSessi
         sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL 
variant_explode_outer(v) AS t"))
     }
   }
+
+  test("explode with udf") {
+    Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode =>
+      withSQLConf("spark.sql.codegen.factoryMode" -> codegenMode) {
+        sql(
+          """create or replace temporary function spark_func (params 
array<struct<x int, y int>>)
+            | returns STRUCT<a: int, b: int> LANGUAGE SQL
+            | return (select ns from (
+            | SELECT try_divide(SUM(item.x * item.y), SUM(item.x * item.x)) AS 
beta1,
+            | NAMED_STRUCT('a', beta1,'b', beta1) ns
+            | FROM (SELECT params) LATERAL VIEW EXPLODE(params) AS item LIMIT 
1));""".stripMargin)
+        val expected = Seq(Row(Row(1, 1)))
+        val actual =
+          sql("""select spark_func(collect_list(NAMED_STRUCT('x', 1, 'y', 1))) 
as result;""")
+        checkAnswer(actual, expected)
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index d6a46d47c104..b5a9d38042df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -86,12 +86,16 @@ case class GenerateExec(
       val generatorNullRow = new 
GenericInternalRow(generator.elementSchema.length)
       val rows = if (requiredChildOutput.nonEmpty) {
 
-        val pruneChildForResult: InternalRow => InternalRow =
-          if (child.outputSet == AttributeSet(requiredChildOutput)) {
+        val pruneChildForResult: InternalRow => InternalRow = {
+          // The declared output of this operator is `requiredChildOutput ++ 
generatorOutput`.
+          // If `child.output` is different from `requiredChildOutput`, we 
must do a projection
+          // to adjust the child output and make sure the final result matches 
the declared output.
+          if (child.output == requiredChildOutput) {
             identity
           } else {
             UnsafeProjection.create(requiredChildOutput, child.output)
           }
+        }
 
         val joinedRow = new JoinedRow
         iter.flatMap { row =>
@@ -142,10 +146,8 @@ case class GenerateExec(
   override def needCopyResult: Boolean = true
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: 
ExprCode): String = {
-    val requiredAttrSet = AttributeSet(requiredChildOutput)
-    val requiredInput = child.output.zip(input).filter {
-      case (attr, _) => requiredAttrSet.contains(attr)
-    }.map(_._2)
+    val attrToInputCode = AttributeMap(child.output.zip(input))
+    val requiredInput = requiredChildOutput.map(attrToInputCode)
     boundGenerator match {
       case e: CollectionGenerator => codeGenCollection(ctx, e, requiredInput)
       case g => codeGenIterableOnce(ctx, g, requiredInput)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index 637e0cf964fe..ad7c297c6f9c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
 class DataFrameTableValuedFunctionsSuite extends QueryTest with 
SharedSparkSession {
@@ -526,4 +527,21 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
       )
     }
   }
+
+  test("explode with udf") {
+    Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode =>
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode)  {
+        sql("""create or replace temporary function spark_func (params 
array<struct<x int, y int>>)
+            | returns STRUCT<a: int, b: int> LANGUAGE SQL
+            | return (select ns from (
+            | SELECT try_divide(SUM(item.x * item.y), SUM(item.x * item.x)) AS 
beta1,
+            | NAMED_STRUCT('a', beta1,'b', beta1) ns
+            | FROM (SELECT params) LATERAL VIEW EXPLODE(params) AS item LIMIT 
1));""".stripMargin)
+        val expected = Seq(Row(Row(1, 1)))
+        val actual =
+          sql("""select spark_func(collect_list(NAMED_STRUCT('x', 1, 'y', 1))) 
as result;""")
+        checkAnswer(actual, expected)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to