This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dd2ee2400e85 [SPARK-52934][PYTHON] Allow yielding scalar values with 
Arrow-optimized Python UDTF
dd2ee2400e85 is described below

commit dd2ee2400e85dc25ce1de12851b27704ed735417
Author: Takuya Ueshin <ues...@databricks.com>
AuthorDate: Thu Jul 24 12:04:17 2025 -0700

    [SPARK-52934][PYTHON] Allow yielding scalar values with Arrow-optimized 
Python UDTF
    
    ### What changes were proposed in this pull request?
    
    Allows yielding scalar values with Arrow-optimized Python UDTF.
    
    This also fixes the error class reported when the number of output
columns differs from the return schema.
    
    ### Why are the changes needed?
    
    There is a behavior difference in Arrow-optimized Python UDTF between 
legacy and new serialization.
    
    - legacy allows yielding scalar values
    - new one doesn't
    
    ```py
    @udtf(returnType="a: int")
    class TestUDTF:
        def eval(self, a: int):
            yield a
    ```
    
    The behavior should be consistent.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the new code path allows yielding scalar values.
    
    ### How was this patch tested?
    
    Updated the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51640 from ueshin/issues/SPARK-52934/yield_scalar.
    
    Authored-by: Takuya Ueshin <ues...@databricks.com>
    Signed-off-by: Takuya Ueshin <ues...@databricks.com>
---
 python/pyspark/sql/tests/test_udtf.py | 68 +++++++++++++++++++----------------
 python/pyspark/worker.py              | 44 ++++++++++++++++-------
 2 files changed, 69 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index b5536ddc7b5d..f7fd8fe6c5b7 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -202,6 +202,30 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, 
"UDTF_INVALID_OUTPUT_ROW_TYPE"):
             TestUDTF(lit(1)).collect()
 
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return [a]
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_INVALID_OUTPUT_ROW_TYPE"):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_eval_returning_tuple_with_struct_type(self):
+        @udtf(returnType="a: struct<b: int, c: int>")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield (a, a + 1),
+
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=Row(b=1, c=2))])
+
+        @udtf(returnType="a: struct<b: int, c: int>")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
+            TestUDTF(lit(1)).collect()
+
     def test_udtf_with_invalid_return_value(self):
         @udtf(returnType="x: int")
         class TestUDTF:
@@ -351,15 +375,13 @@ class BaseUDTFTestsMixin:
             TestUDTF(lit(1)).show()
 
     def test_udtf_with_wrong_num_output(self):
-        err_msg = 
"(UDTF_ARROW_TYPE_CONVERSION_ERROR|UDTF_RETURN_SCHEMA_MISMATCH)"
-
         # Output less columns than specified return schema
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
             def eval(self, a: int):
                 yield a,
 
-        with self.assertRaisesRegex(PythonException, err_msg):
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF(lit(1)).collect()
 
         # Output more columns than specified return schema
@@ -368,7 +390,7 @@ class BaseUDTFTestsMixin:
             def eval(self, a: int):
                 yield a, a + 1
 
-        with self.assertRaisesRegex(PythonException, err_msg):
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF(lit(1)).collect()
 
     def test_udtf_with_empty_output_schema_and_non_empty_output(self):
@@ -377,9 +399,7 @@ class BaseUDTFTestsMixin:
             def eval(self):
                 yield 1,
 
-        with self.assertRaisesRegex(
-            PythonException, 
"(UDTF_RETURN_SCHEMA_MISMATCH|UDTF_ARROW_TYPE_CONVERSION_ERROR)"
-        ):
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF().collect()
 
     def test_udtf_with_non_empty_output_schema_and_empty_output(self):
@@ -388,9 +408,7 @@ class BaseUDTFTestsMixin:
             def eval(self):
                 yield tuple()
 
-        with self.assertRaisesRegex(
-            PythonException, 
"(UDTF_RETURN_SCHEMA_MISMATCH|UDTF_ARROW_TYPE_CONVERSION_ERROR)"
-        ):
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF().collect()
 
     def test_udtf_init(self):
@@ -545,8 +563,6 @@ class BaseUDTFTestsMixin:
             TestUDTF(lit(1)).collect()
 
     def test_udtf_terminate_with_wrong_num_output(self):
-        err_msg = 
"(UDTF_RETURN_SCHEMA_MISMATCH|UDTF_ARROW_TYPE_CONVERSION_ERROR)"
-
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
             def eval(self, a: int):
@@ -555,7 +571,7 @@ class BaseUDTFTestsMixin:
             def terminate(self):
                 yield 1, 2, 3
 
-        with self.assertRaisesRegex(PythonException, err_msg):
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF(lit(1)).show()
 
         @udtf(returnType="a: int, b: int")
@@ -566,7 +582,7 @@ class BaseUDTFTestsMixin:
             def terminate(self):
                 yield 1,
 
-        with self.assertRaisesRegex(PythonException, err_msg):
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
             TestUDTF(lit(1)).show()
 
     def test_udtf_determinism(self):
@@ -2905,6 +2921,13 @@ class LegacyUDTFArrowTestsMixin(BaseUDTFTestsMixin):
         # When arrow is enabled, it can handle non-tuple return value.
         assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])
 
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return (a,)
+
+        assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)])
+
         @udtf(returnType="a: int")
         class TestUDTF:
             def eval(self, a: int):
@@ -3158,23 +3181,6 @@ class LegacyUDTFArrowTests(LegacyUDTFArrowTestsMixin, 
ReusedSQLTestCase):
 
 
 class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin):
-    def test_udtf_eval_returning_non_tuple(self):
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a
-
-        with self.assertRaisesRegex(PythonException, 
"UDTF_ARROW_TYPE_CONVERSION_ERROR"):
-            TestUDTF(lit(1)).collect()
-
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                return [a]
-
-        with self.assertRaisesRegex(PythonException, 
"UDTF_ARROW_TYPE_CONVERSION_ERROR"):
-            TestUDTF(lit(1)).collect()
-
     def test_numeric_output_type_casting(self):
         class TestUDTF:
             def eval(self):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index adcb5dcf3588..3c869b3dba90 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -80,7 +80,7 @@ from pyspark.sql.types import (
 )
 from pyspark.util import fail_on_stopiteration, handle_worker_exception
 from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError, 
PySparkValueError
 from pyspark.worker_util import (
     check_python_version,
     read_command,
@@ -1821,13 +1821,13 @@ def read_udtf(pickleSer, infile, eval_type):
                                 "func": f.__name__,
                             },
                         )
-                    if check_output_row_against_schema is not None:
-                        for row in res:
+                    for row in res:
+                        if not isinstance(row, tuple) and return_type_size == 
1:
+                            row = (row,)
+                        if check_output_row_against_schema is not None:
                             if row is not None:
                                 check_output_row_against_schema(row)
-                            yield row
-                    else:
-                        yield from res
+                        yield row
 
             def convert_to_arrow(data: Iterable):
                 data = list(check_return_value(data))
@@ -1835,11 +1835,8 @@ def read_udtf(pickleSer, infile, eval_type):
                     return [
                         pa.RecordBatch.from_pylist(data, 
schema=pa.schema(list(arrow_return_type)))
                     ]
-                try:
-                    return LocalDataToArrowConversion.convert(
-                        data, return_type, prefers_large_var_types
-                    ).to_batches()
-                except Exception as e:
+
+                def raise_conversion_error(original_exception):
                     raise PySparkRuntimeError(
                         errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR",
                         messageParameters={
@@ -1847,7 +1844,30 @@ def read_udtf(pickleSer, infile, eval_type):
                             "schema": return_type.simpleString(),
                             "arrow_schema": str(arrow_return_type),
                         },
-                    ) from e
+                    ) from original_exception
+
+                try:
+                    return LocalDataToArrowConversion.convert(
+                        data, return_type, prefers_large_var_types
+                    ).to_batches()
+                except PySparkValueError as e:
+                    if e.getErrorClass() == "AXIS_LENGTH_MISMATCH":
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
+                            messageParameters={
+                                "expected": e.getMessageParameters()[
+                                    "expected_length"
+                                ],  # type: ignore[index]
+                                "actual": e.getMessageParameters()[
+                                    "actual_length"
+                                ],  # type: ignore[index]
+                                "func": f.__name__,
+                            },
+                        ) from e
+                    # Fall through to general conversion error
+                    raise_conversion_error(e)
+                except Exception as e:
+                    raise_conversion_error(e)
 
             def evaluate(*args: pa.ChunkedArray):
                 if len(args) == 0:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to