This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cdc89aea9ac6 [SPARK-52961][PYTHON] Fix Arrow-optimized Python UDTF with 0-arg eval on lateral join
cdc89aea9ac6 is described below
commit cdc89aea9ac6f7ef7b3ad0aa8a14a68501aaa5dd
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jul 28 09:05:32 2025 +0900
[SPARK-52961][PYTHON] Fix Arrow-optimized Python UDTF with 0-arg eval on
lateral join
### What changes were proposed in this pull request?
Fixes Arrow-optimized Python UDTF with 0-arg eval on lateral join.
### Why are the changes needed?
The Arrow-optimized Python UDTF with a 0-arg eval returns fewer rows than expected.
Both legacy and non-legacy code paths are affected.
```py
>>> udtf(returnType="i: int", useArrow=True)
... class TestUDTF:
... def eval(self):
... yield 0,
...
>>> spark.range(3, numPartitions=1).lateralJoin(TestUDTF()).show()
+---+---+
| id| i|
+---+---+
| 0| 0|
+---+---+
```
It should be:
```py
+---+---+
| id| i|
+---+---+
| 0| 0|
| 1| 0|
| 2| 0|
+---+---+
```
### Does this PR introduce _any_ user-facing change?
Yes, Arrow-optimized Python UDTF with 0-arg eval will work well with
lateral join.
### How was this patch tested?
Added the related test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51672 from ueshin/issues/SPARK-52961/0-arg.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_udtf.py | 20 ++++++++++++++++++++
python/pyspark/worker.py | 22 ++++++++++++----------
2 files changed, 32 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 43ec95c2a076..2bb7c6d1f176 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -176,6 +176,26 @@ class BaseUDTFTestsMixin:
             self.spark.sql("SELECT * FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b)"),
         )
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self):
+ yield 1,
+ yield 2,
+
+ self.spark.udtf.register("testUDTF", TestUDTF)
+
+ assertDataFrameEqual(
+ self.spark.range(3, numPartitions=1).lateralJoin(TestUDTF()),
+ [
+ Row(id=0, a=1),
+ Row(id=0, a=2),
+ Row(id=1, a=1),
+ Row(id=1, a=2),
+ Row(id=2, a=1),
+ Row(id=2, a=2),
+ ],
+ )
+
def test_udtf_eval_with_return_stmt(self):
class TestUDTF:
def eval(self, a: int, b: int):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 342ebc14311f..c5e632770bf2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1706,18 +1706,19 @@ def read_udtf(pickleSer, infile, eval_type):
else:
yield from res
-        def evaluate(*args: pd.Series):
+        def evaluate(*args: pd.Series, num_rows=1):
             if len(args) == 0:
-                res = func()
-                yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
+                for _ in range(num_rows):
+                    yield verify_result(
+                        pd.DataFrame(check_return_value(func()))
+                    ), arrow_return_type
else:
# Create tuples from the input pandas Series, each tuple
# represents a row across all Series.
row_tuples = zip(*args)
for row in row_tuples:
- res = func(*row)
yield verify_result(
- pd.DataFrame(check_return_value(res))
+ pd.DataFrame(check_return_value(func(*row)))
), arrow_return_type
return evaluate
@@ -1739,7 +1740,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 for a in it:
                     # The eval function yields an iterator. Each element produced by this
                     # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
-                    yield from eval(*[a[o] for o in args_kwargs_offsets])
+                    yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=len(a[0]))
if terminate is not None:
yield from terminate()
except SkipRestOfInputTableException:
@@ -1867,10 +1868,11 @@ def read_udtf(pickleSer, infile, eval_type):
except Exception as e:
raise_conversion_error(e)
-        def evaluate(*args: pa.ChunkedArray):
+        def evaluate(*args: pa.ChunkedArray, num_rows=1):
             if len(args) == 0:
-                for batch in verify_result(convert_to_arrow(func())).to_batches():
-                    yield batch, arrow_return_type
+                for _ in range(num_rows):
+                    for batch in verify_result(convert_to_arrow(func())).to_batches():
+                        yield batch, arrow_return_type
else:
list_args = list(args)
@@ -1903,7 +1905,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 for a in it:
                     # The eval function yields an iterator. Each element produced by this
                     # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
-                    yield from eval(*[a[o] for o in args_kwargs_offsets])
+                    yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=a.num_rows)
if terminate is not None:
yield from terminate()
except SkipRestOfInputTableException:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]