This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new dd2ee2400e85 [SPARK-52934][PYTHON] Allow yielding scalar values with Arrow-optimized Python UDTF dd2ee2400e85 is described below commit dd2ee2400e85dc25ce1de12851b27704ed735417 Author: Takuya Ueshin <ues...@databricks.com> AuthorDate: Thu Jul 24 12:04:17 2025 -0700 [SPARK-52934][PYTHON] Allow yielding scalar values with Arrow-optimized Python UDTF ### What changes were proposed in this pull request? Allows yielding scalar values with Arrow-optimized Python UDTF. Also fixes the error class raised when the number of output columns differs from the return schema: the new code path now reports UDTF_RETURN_SCHEMA_MISMATCH instead of UDTF_ARROW_TYPE_CONVERSION_ERROR. ### Why are the changes needed? Arrow-optimized Python UDTFs behave differently under the legacy and the new serialization. - the legacy one allows yielding scalar values - the new one doesn't ```py @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a ``` The behavior should be consistent. ### Does this PR introduce _any_ user-facing change? Yes, the new code path allows yielding scalar values. ### How was this patch tested? Updated the related tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51640 from ueshin/issues/SPARK-52934/yield_scalar. 
Authored-by: Takuya Ueshin <ues...@databricks.com> Signed-off-by: Takuya Ueshin <ues...@databricks.com> --- python/pyspark/sql/tests/test_udtf.py | 68 +++++++++++++++++++---------------- python/pyspark/worker.py | 44 ++++++++++++++++------- 2 files changed, 69 insertions(+), 43 deletions(-) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index b5536ddc7b5d..f7fd8fe6c5b7 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -202,6 +202,30 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): TestUDTF(lit(1)).collect() + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a: int): + return [a] + + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() + + def test_udtf_eval_returning_tuple_with_struct_type(self): + @udtf(returnType="a: struct<b: int, c: int>") + class TestUDTF: + def eval(self, a: int): + yield (a, a + 1), + + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=Row(b=1, c=2))]) + + @udtf(returnType="a: struct<b: int, c: int>") + class TestUDTF: + def eval(self, a: int): + yield a, a + 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): + TestUDTF(lit(1)).collect() + def test_udtf_with_invalid_return_value(self): @udtf(returnType="x: int") class TestUDTF: @@ -351,15 +375,13 @@ class BaseUDTFTestsMixin: TestUDTF(lit(1)).show() def test_udtf_with_wrong_num_output(self): - err_msg = "(UDTF_ARROW_TYPE_CONVERSION_ERROR|UDTF_RETURN_SCHEMA_MISMATCH)" - # Output less columns than specified return schema @udtf(returnType="a: int, b: int") class TestUDTF: def eval(self, a: int): yield a, - with self.assertRaisesRegex(PythonException, err_msg): + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): TestUDTF(lit(1)).collect() # Output more columns than specified return schema @@ -368,7 +390,7 @@ class 
BaseUDTFTestsMixin: def eval(self, a: int): yield a, a + 1 - with self.assertRaisesRegex(PythonException, err_msg): + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): TestUDTF(lit(1)).collect() def test_udtf_with_empty_output_schema_and_non_empty_output(self): @@ -377,9 +399,7 @@ class BaseUDTFTestsMixin: def eval(self): yield 1, - with self.assertRaisesRegex( - PythonException, "(UDTF_RETURN_SCHEMA_MISMATCH|UDTF_ARROW_TYPE_CONVERSION_ERROR)" - ): + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): TestUDTF().collect() def test_udtf_with_non_empty_output_schema_and_empty_output(self): @@ -388,9 +408,7 @@ class BaseUDTFTestsMixin: def eval(self): yield tuple() - with self.assertRaisesRegex( - PythonException, "(UDTF_RETURN_SCHEMA_MISMATCH|UDTF_ARROW_TYPE_CONVERSION_ERROR)" - ): + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): TestUDTF().collect() def test_udtf_init(self): @@ -545,8 +563,6 @@ class BaseUDTFTestsMixin: TestUDTF(lit(1)).collect() def test_udtf_terminate_with_wrong_num_output(self): - err_msg = "(UDTF_RETURN_SCHEMA_MISMATCH|UDTF_ARROW_TYPE_CONVERSION_ERROR)" - @udtf(returnType="a: int, b: int") class TestUDTF: def eval(self, a: int): @@ -555,7 +571,7 @@ class BaseUDTFTestsMixin: def terminate(self): yield 1, 2, 3 - with self.assertRaisesRegex(PythonException, err_msg): + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): TestUDTF(lit(1)).show() @udtf(returnType="a: int, b: int") @@ -566,7 +582,7 @@ class BaseUDTFTestsMixin: def terminate(self): yield 1, - with self.assertRaisesRegex(PythonException, err_msg): + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): TestUDTF(lit(1)).show() def test_udtf_determinism(self): @@ -2905,6 +2921,13 @@ class LegacyUDTFArrowTestsMixin(BaseUDTFTestsMixin): # When arrow is enabled, it can handle non-tuple return value. 
assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a: int): + return (a,) + + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): @@ -3158,23 +3181,6 @@ class LegacyUDTFArrowTests(LegacyUDTFArrowTestsMixin, ReusedSQLTestCase): class UDTFArrowTestsMixin(LegacyUDTFArrowTestsMixin): - def test_udtf_eval_returning_non_tuple(self): - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield a - - with self.assertRaisesRegex(PythonException, "UDTF_ARROW_TYPE_CONVERSION_ERROR"): - TestUDTF(lit(1)).collect() - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - return [a] - - with self.assertRaisesRegex(PythonException, "UDTF_ARROW_TYPE_CONVERSION_ERROR"): - TestUDTF(lit(1)).collect() - def test_numeric_output_type_casting(self): class TestUDTF: def eval(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index adcb5dcf3588..3c869b3dba90 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -80,7 +80,7 @@ from pyspark.sql.types import ( ) from pyspark.util import fail_on_stopiteration, handle_worker_exception from pyspark import shuffle -from pyspark.errors import PySparkRuntimeError, PySparkTypeError +from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError from pyspark.worker_util import ( check_python_version, read_command, @@ -1821,13 +1821,13 @@ def read_udtf(pickleSer, infile, eval_type): "func": f.__name__, }, ) - if check_output_row_against_schema is not None: - for row in res: + for row in res: + if not isinstance(row, tuple) and return_type_size == 1: + row = (row,) + if check_output_row_against_schema is not None: if row is not None: check_output_row_against_schema(row) - yield row - else: - yield from res + yield row def convert_to_arrow(data: Iterable): data = list(check_return_value(data)) @@ -1835,11 +1835,8 @@ def 
read_udtf(pickleSer, infile, eval_type): return [ pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type))) ] - try: - return LocalDataToArrowConversion.convert( - data, return_type, prefers_large_var_types - ).to_batches() - except Exception as e: + + def raise_conversion_error(original_exception): raise PySparkRuntimeError( errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR", messageParameters={ @@ -1847,7 +1844,30 @@ def read_udtf(pickleSer, infile, eval_type): "schema": return_type.simpleString(), "arrow_schema": str(arrow_return_type), }, - ) from e + ) from original_exception + + try: + return LocalDataToArrowConversion.convert( + data, return_type, prefers_large_var_types + ).to_batches() + except PySparkValueError as e: + if e.getErrorClass() == "AXIS_LENGTH_MISMATCH": + raise PySparkRuntimeError( + errorClass="UDTF_RETURN_SCHEMA_MISMATCH", + messageParameters={ + "expected": e.getMessageParameters()[ + "expected_length" + ], # type: ignore[index] + "actual": e.getMessageParameters()[ + "actual_length" + ], # type: ignore[index] + "func": f.__name__, + }, + ) from e + # Fall through to general conversion error + raise_conversion_error(e) + except Exception as e: + raise_conversion_error(e) def evaluate(*args: pa.ChunkedArray): if len(args) == 0: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org