This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cdc89aea9ac6 [SPARK-52961][PYTHON] Fix Arrow-optimized Python UDTF with 0-arg eval on lateral join
cdc89aea9ac6 is described below

commit cdc89aea9ac6f7ef7b3ad0aa8a14a68501aaa5dd
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jul 28 09:05:32 2025 +0900

    [SPARK-52961][PYTHON] Fix Arrow-optimized Python UDTF with 0-arg eval on lateral join
    
    ### What changes were proposed in this pull request?
    
    Fixes Arrow-optimized Python UDTF with 0-arg eval on lateral join: the
    `evaluate` wrappers in `python/pyspark/worker.py` now receive the input
    batch's row count (`num_rows`) and invoke the 0-arg `eval` once per input
    row instead of once per batch.
    
    ### Why are the changes needed?
    
    An Arrow-optimized Python UDTF with a 0-arg `eval` returns fewer rows than
    expected on a lateral join. Both the legacy and non-legacy code paths are
    affected.
    
    ```py
    >>> @udtf(returnType="i: int", useArrow=True)
    ... class TestUDTF:
    ...     def eval(self):
    ...         yield 0,
    ...
    >>> spark.range(3, numPartitions=1).lateralJoin(TestUDTF()).show()
    +---+---+
    | id|  i|
    +---+---+
    |  0|  0|
    +---+---+
    ```
    
    It should be:
    
    ```py
    +---+---+
    | id|  i|
    +---+---+
    |  0|  0|
    |  1|  0|
    |  2|  0|
    +---+---+
    ```
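
    The root cause is in the `evaluate` wrappers in `python/pyspark/worker.py`:
    with a 0-arg `eval` there are no per-row argument Series to iterate over,
    so the wrapper invoked the UDTF once per input batch instead of once per
    input row. A minimal runnable sketch of the fixed control flow (`func` and
    the surrounding names are illustrative stand-ins; the real wrapper also
    verifies results and pairs each frame with `arrow_return_type`):

    ```py
    import pandas as pd

    def func():
        # Illustrative 0-arg UDTF eval: yields a single row per invocation.
        yield 0,

    def evaluate(*args: pd.Series, num_rows: int = 1):
        if len(args) == 0:
            # Before the fix this loop was missing: func() ran once per batch,
            # so a 3-row input batch produced only one output row.
            for _ in range(num_rows):
                yield pd.DataFrame(list(func()))
        else:
            # Each tuple zipped across the input Series is one input row.
            for row in zip(*args):
                yield pd.DataFrame(list(func(*row)))

    # A 3-row batch with a 0-arg eval now yields one result per input row.
    assert len(list(evaluate(num_rows=3))) == 3
    ```

    The Arrow code path applies the same idea, taking `num_rows` from the
    incoming batch (`a.num_rows`).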
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, Arrow-optimized Python UDTFs with a 0-arg eval now work correctly
    with lateral joins.
    
    ### How was this patch tested?
    
    Added a related test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51672 from ueshin/issues/SPARK-52961/0-arg.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/tests/test_udtf.py | 20 ++++++++++++++++++++
 python/pyspark/worker.py              | 22 ++++++++++++----------
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 43ec95c2a076..2bb7c6d1f176 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -176,6 +176,26 @@ class BaseUDTFTestsMixin:
             self.spark.sql("SELECT * FROM values (0, 1), (1, 2) t(a, b), LATERAL testUDTF(a, b)"),
         )
 
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+                yield 2,
+
+        self.spark.udtf.register("testUDTF", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.range(3, numPartitions=1).lateralJoin(TestUDTF()),
+            [
+                Row(id=0, a=1),
+                Row(id=0, a=2),
+                Row(id=1, a=1),
+                Row(id=1, a=2),
+                Row(id=2, a=1),
+                Row(id=2, a=2),
+            ],
+        )
+
     def test_udtf_eval_with_return_stmt(self):
         class TestUDTF:
             def eval(self, a: int, b: int):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 342ebc14311f..c5e632770bf2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1706,18 +1706,19 @@ def read_udtf(pickleSer, infile, eval_type):
                     else:
                         yield from res
 
-            def evaluate(*args: pd.Series):
+            def evaluate(*args: pd.Series, num_rows=1):
                 if len(args) == 0:
-                    res = func()
-                    yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
+                    for _ in range(num_rows):
+                        yield verify_result(
+                            pd.DataFrame(check_return_value(func()))
+                        ), arrow_return_type
                 else:
                     # Create tuples from the input pandas Series, each tuple
                     # represents a row across all Series.
                     row_tuples = zip(*args)
                     for row in row_tuples:
-                        res = func(*row)
                         yield verify_result(
-                            pd.DataFrame(check_return_value(res))
+                            pd.DataFrame(check_return_value(func(*row)))
                         ), arrow_return_type
 
             return evaluate
@@ -1739,7 +1740,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 for a in it:
                     # The eval function yields an iterator. Each element produced by this
                     # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
-                    yield from eval(*[a[o] for o in args_kwargs_offsets])
+                    yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=len(a[0]))
                 if terminate is not None:
                     yield from terminate()
             except SkipRestOfInputTableException:
@@ -1867,10 +1868,11 @@ def read_udtf(pickleSer, infile, eval_type):
                 except Exception as e:
                     raise_conversion_error(e)
 
-            def evaluate(*args: pa.ChunkedArray):
+            def evaluate(*args: pa.ChunkedArray, num_rows=1):
                 if len(args) == 0:
-                    for batch in verify_result(convert_to_arrow(func())).to_batches():
-                        yield batch, arrow_return_type
+                    for _ in range(num_rows):
+                        for batch in verify_result(convert_to_arrow(func())).to_batches():
+                            yield batch, arrow_return_type
 
                 else:
                     list_args = list(args)
@@ -1903,7 +1905,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 for a in it:
                     # The eval function yields an iterator. Each element produced by this
                     # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
-                    yield from eval(*[a[o] for o in args_kwargs_offsets])
+                    yield from eval(*[a[o] for o in args_kwargs_offsets], num_rows=a.num_rows)
                 if terminate is not None:
                     yield from terminate()
             except SkipRestOfInputTableException:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
