This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 57dfc0a21bf2 [SPARK-53039][PYTHON][TESTS] Add unit test for complex arrow UDF used in window
57dfc0a21bf2 is described below

commit 57dfc0a21bf2ec54d0a9dcc51b62ccfc26b07e0f
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jul 31 18:47:02 2025 -0700

    [SPARK-53039][PYTHON][TESTS] Add unit test for complex arrow UDF used in window
    
    ### What changes were proposed in this pull request?
    Add unit tests for Arrow UDFs that return complex types (array, map, struct) when used over a window.
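    
    For reference, the pattern these tests exercise looks roughly like the minimal sketch below. It assumes a running SparkSession bound to `spark` and the `arrow_udf` decorator from `pyspark.sql.functions`; the actual tests are in the diff that follows.
    
        import pyarrow as pa
        from pyspark.sql.functions import arrow_udf
        from pyspark.sql.window import Window
    
        # An Arrow UDF returning a complex (array) type: it receives the
        # values of the window frame as a pa.Array and must return a
        # single pa.Scalar of the declared return type.
        @arrow_udf("array<int>")
        def arrow_collect_set(v: pa.Array) -> pa.Scalar:
            s = sorted(x.as_py() for x in pa.compute.unique(v))
            return pa.scalar(s, type=pa.list_(pa.int32()))
    
        df = spark.createDataFrame([(1, 1), (1, 2), (2, 3)], ("id", "v"))
        w = Window.partitionBy("id").orderBy("v")
        df.select(arrow_collect_set("v").over(w).alias("vs")).show()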
    
    ### Why are the changes needed?
    To improve test coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No, test-only.
    
    ### How was this patch tested?
    CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #51746 from zhengruifeng/arrow_complex_win.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 .../sql/tests/arrow/test_arrow_udf_window.py       | 103 +++++++++++++++++++++
 1 file changed, 103 insertions(+)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index 42281b9caf49..a66ccc0bd717 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -549,6 +549,109 @@ class WindowArrowUDFTestsMixin:
                             )
                         ).show()
 
+    def test_complex_window_collect_set(self):
+        import pyarrow as pa
+
+        df = self.spark.createDataFrame([(1, 1), (1, 2), (2, 3), (2, 5), (2, 3)], ("id", "v"))
+        w = Window.partitionBy("id").orderBy("v")
+
+        @arrow_udf("array<int>")
+        def arrow_collect_set(v: pa.Array) -> pa.Scalar:
+            assert isinstance(v, pa.Array), str(type(v))
+            s = sorted([x.as_py() for x in pa.compute.unique(v)])
+            t = pa.list_(pa.int32())
+            return pa.scalar(value=s, type=t)
+
+        result1 = df.select(
+            arrow_collect_set(df["v"]).over(w).alias("vs"),
+        )
+
+        expected1 = df.select(
+            sf.sort_array(sf.collect_set(df["v"]).over(w)).alias("vs"),
+        )
+
+        self.assertEqual(expected1.collect(), result1.collect())
+
+    def test_complex_window_collect_list(self):
+        import pyarrow as pa
+
+        df = self.spark.createDataFrame([(1, 1), (1, 2), (2, 3), (2, 5), (2, 3)], ("id", "v"))
+        w = Window.partitionBy("id").orderBy("v")
+
+        @arrow_udf("array<int>")
+        def arrow_collect_list(v: pa.Array) -> pa.Scalar:
+            assert isinstance(v, pa.Array), str(type(v))
+            s = sorted([x.as_py() for x in v])
+            t = pa.list_(pa.int32())
+            return pa.scalar(value=s, type=t)
+
+        result1 = df.select(
+            arrow_collect_list(df["v"]).over(w).alias("vs"),
+        )
+
+        expected1 = df.select(
+            sf.sort_array(sf.collect_list(df["v"]).over(w)).alias("vs"),
+        )
+
+        self.assertEqual(expected1.collect(), result1.collect())
+
+    def test_complex_window_collect_as_map(self):
+        import pyarrow as pa
+
+        df = self.spark.createDataFrame(
+            [(1, 2, 1), (1, 3, 2), (2, 4, 3), (2, 5, 5), (2, 6, 3)], ("id", "k", "v")
+        )
+        w = Window.partitionBy("id").orderBy("v")
+
+        @arrow_udf("map<int, int>")
+        def arrow_collect_as_map(id: pa.Array, v: pa.Array) -> pa.Scalar:
+            assert isinstance(id, pa.Array), str(type(id))
+            assert isinstance(v, pa.Array), str(type(v))
+            d = {i: j for i, j in zip(id.to_pylist(), v.to_pylist())}
+            t = pa.map_(pa.int32(), pa.int32())
+            return pa.scalar(value=d, type=t)
+
+        result1 = df.select(
+            arrow_collect_as_map("k", "v").over(w).alias("map"),
+        )
+
+        expected1 = df.select(
+            sf.map_from_arrays(
+                sf.collect_list("k").over(w),
+                sf.collect_list("v").over(w),
+            ).alias("map")
+        )
+
+        self.assertEqual(expected1.collect(), result1.collect())
+
+    def test_complex_window_min_max_struct(self):
+        import pyarrow as pa
+
+        df = self.spark.createDataFrame([(1, 1), (1, 2), (2, 3), (2, 5), (2, 3)], ("id", "v"))
+        w = Window.partitionBy("id").orderBy("v")
+
+        @arrow_udf("struct<m1: int, m2:int>")
+        def arrow_collect_min_max(id: pa.Array, v: pa.Array) -> pa.Scalar:
+            assert isinstance(id, pa.Array), str(type(id))
+            assert isinstance(v, pa.Array), str(type(v))
+            m1 = pa.compute.min(id)
+            m2 = pa.compute.max(v)
+            t = pa.struct([pa.field("m1", pa.int32()), pa.field("m2", pa.int32())])
+            return pa.scalar(value={"m1": m1.as_py(), "m2": m2.as_py()}, type=t)
+
+        result1 = df.select(
+            arrow_collect_min_max("id", "v").over(w).alias("struct"),
+        )
+
+        expected1 = df.select(
+            sf.struct(
+                sf.min("id").over(w).alias("m1"),
+                sf.max("v").over(w).alias("m2"),
+            ).alias("struct")
+        )
+
+        self.assertEqual(expected1.collect(), result1.collect())
+
     def test_return_type_coercion(self):
         import pyarrow as pa
 

