This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d3c0f2f30b8f [SPARK-52033][SQL] Fix Generate node bug where the output of the child node can have multiple copies of the same Attribute d3c0f2f30b8f is described below commit d3c0f2f30b8f056fde699db756c8b5798a7bcb1d Author: Harsh Motwani <harsh.motw...@databricks.com> AuthorDate: Fri May 9 10:46:38 2025 +0800 [SPARK-52033][SQL] Fix Generate node bug where the output of the child node can have multiple copies of the same Attribute ### What changes were proposed in this pull request? Instead of checking `requiredAttrSet.contains(attr)`, this PR proposes that we also verify the count of these attributes. ### Why are the changes needed? Sometimes, one of the child's attributes could be more frequent in the input than the output but all the copies would pass through, resulting in an assertion failure in Codegen. In interpreted mode, we see incorrect results. This is a repro which fails in codegen but returns `{null, null}` as opposed to `{1, 1}` in interpreted mode: ``` sql("""create or replace temporary function spark_func (params array<struct<x int, y int>>) | returns STRUCT<a: int, b: int> LANGUAGE SQL | return (select ns from ( | SELECT try_divide(SUM(item.x * item.y), SUM(item.x * item.x)) AS beta1, | NAMED_STRUCT('a', beta1,'b', beta1) ns | FROM (SELECT params) LATERAL VIEW EXPLODE(params) AS item LIMIT 1));""".stripMargin) sql("""select spark_func(collect_list(NAMED_STRUCT('x', 1, 'y', 1))) as result;""").collect() ``` ### Does this PR introduce _any_ user-facing change? Yes, described above. ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #50823 from harshmotw-db/harsh-motwani_data/generate_udf_fix. 
Lead-authored-by: Harsh Motwani <harsh.motw...@databricks.com> Co-authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../spark/sql/DataFrameTableValuedFunctionsSuite.scala | 18 ++++++++++++++++++ .../org/apache/spark/sql/execution/GenerateExec.scala | 14 ++++++++------ .../spark/sql/DataFrameTableValuedFunctionsSuite.scala | 18 ++++++++++++++++++ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala index d11f276e8ed4..e619279940d2 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala @@ -523,4 +523,22 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL variant_explode_outer(v) AS t")) } } + + test("explode with udf") { + Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode => + withSQLConf("spark.sql.codegen.factoryMode" -> codegenMode) { + sql( + """create or replace temporary function spark_func (params array<struct<x int, y int>>) + | returns STRUCT<a: int, b: int> LANGUAGE SQL + | return (select ns from ( + | SELECT try_divide(SUM(item.x * item.y), SUM(item.x * item.x)) AS beta1, + | NAMED_STRUCT('a', beta1,'b', beta1) ns + | FROM (SELECT params) LATERAL VIEW EXPLODE(params) AS item LIMIT 1));""".stripMargin) + val expected = Seq(Row(Row(1, 1))) + val actual = + sql("""select spark_func(collect_list(NAMED_STRUCT('x', 1, 'y', 1))) as result;""") + checkAnswer(actual, expected) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index d6a46d47c104..b5a9d38042df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -86,12 +86,16 @@ case class GenerateExec( val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) val rows = if (requiredChildOutput.nonEmpty) { - val pruneChildForResult: InternalRow => InternalRow = - if (child.outputSet == AttributeSet(requiredChildOutput)) { + val pruneChildForResult: InternalRow => InternalRow = { + // The declared output of this operator is `requiredChildOutput ++ generatorOutput`. + // If `child.output` is different from `requiredChildOutput`, we must do a projection + // to adjust the child output and make sure the final result matches the declared output. + if (child.output == requiredChildOutput) { identity } else { UnsafeProjection.create(requiredChildOutput, child.output) } + } val joinedRow = new JoinedRow iter.flatMap { row => @@ -142,10 +146,8 @@ case class GenerateExec( override def needCopyResult: Boolean = true override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val requiredAttrSet = AttributeSet(requiredChildOutput) - val requiredInput = child.output.zip(input).filter { - case (attr, _) => requiredAttrSet.contains(attr) - }.map(_._2) + val attrToInputCode = AttributeMap(child.output.zip(input)) + val requiredInput = requiredChildOutput.map(attrToInputCode) boundGenerator match { case e: CollectionGenerator => codeGenCollection(ctx, e, requiredInput) case g => codeGenIterableOnce(ctx, g, requiredInput) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala index 637e0cf964fe..ad7c297c6f9c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSession { @@ -526,4 +527,21 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi ) } } + + test("explode with udf") { + Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + sql("""create or replace temporary function spark_func (params array<struct<x int, y int>>) + | returns STRUCT<a: int, b: int> LANGUAGE SQL + | return (select ns from ( + | SELECT try_divide(SUM(item.x * item.y), SUM(item.x * item.x)) AS beta1, + | NAMED_STRUCT('a', beta1,'b', beta1) ns + | FROM (SELECT params) LATERAL VIEW EXPLODE(params) AS item LIMIT 1));""".stripMargin) + val expected = Seq(Row(Row(1, 1))) + val actual = + sql("""select spark_func(collect_list(NAMED_STRUCT('x', 1, 'y', 1))) as result;""") + checkAnswer(actual, expected) + } + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org