This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ae819840bbf [SPARK-53611][PYTHON] Limit Arrow batch sizes in window agg UDFs
9ae819840bbf is described below

commit 9ae819840bbf2f6c8fc8e7d978c89f3d11b57b05
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Oct 15 09:59:20 2025 +0800

    [SPARK-53611][PYTHON] Limit Arrow batch sizes in window agg UDFs
    
    ### What changes were proposed in this pull request?
    Limit Arrow batch sizes in window agg UDFs
    
    ### Why are the changes needed?
    To avoid OOM on the JVM side, by bounding the size of the JVM->Python Arrow batches
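
    As a rough illustration of the path this bounds, here is a minimal sketch
    using a grouped-agg pandas UDF over a window (the session, DataFrame, and
    UDF below are illustrative and not part of this patch):

        from pyspark.sql import SparkSession, functions as sf
        from pyspark.sql.functions import pandas_udf, PandasUDFType
        from pyspark.sql.window import Window

        spark = SparkSession.builder.getOrCreate()

        # These two limits bound the Arrow batches streamed from the JVM to the
        # Python worker; with this patch they also apply to window agg UDFs.
        spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
        spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", str(64 * 1024 * 1024))

        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
        def pandas_sum(v):
            # v is a pandas Series holding the values of the current window frame
            return v.sum()

        df = spark.range(1_000_000).select(
            (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
        )
        w = Window.partitionBy("key").orderBy("v")
        df.select("*", pandas_sum("v").over(w).alias("res")).show()

    Roughly, maxRecordsPerBatch caps the rows per batch and maxBytesPerBatch
    caps the bytes per batch, so the JVM does not have to materialize an
    arbitrarily large batch before handing data to Python.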
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52608 from zhengruifeng/limit_win_agg.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/serializers.py           |   5 +-
 .../sql/tests/arrow/test_arrow_udf_window.py       |  46 +++++++
 .../sql/tests/pandas/test_pandas_udf_window.py     | 141 ++++++++++++---------
 python/pyspark/worker.py                           |  11 +-
 .../python/ArrowWindowPythonEvaluatorFactory.scala |   6 +-
 5 files changed, 143 insertions(+), 66 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 323aea3a59ce..89b30668424a 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1143,7 +1143,8 @@ class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
         return "GroupArrowUDFSerializer"
 
 
-class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
     def __init__(
         self,
         timezone,
@@ -1183,7 +1184,7 @@ class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
                 )
 
     def __repr__(self):
-        return "AggArrowUDFSerializer"
+        return "ArrowStreamAggArrowUDFSerializer"
 
 
 class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index 1d301597c21a..b3ed4c020ca3 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -758,6 +758,52 @@ class WindowArrowUDFTestsMixin:
         )
         self.assertEqual(expected.collect(), result.collect())
 
+    def test_arrow_batch_slicing(self):
+        import pyarrow as pa
+
+        df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+        w1 = Window.partitionBy("key").orderBy("v")
+        w2 = (
+            Window.partitionBy("key")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+            .orderBy("v")
+        )
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_sum(v):
+            return pa.compute.sum(v)
+
+        @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+        def arrow_sum_unbounded(v):
+            assert len(v) == 1000 / 2, len(v)
+            return pa.compute.sum(v)
+
+        expected1 = df.select("*", sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+        expected2 = df.select("*", sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+        for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result1 = (
+                        df.select("*", arrow_sum("v").over(w1).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected1, result1)
+
+                    result2 = (
+                        df.select("*", arrow_sum_unbounded("v").over(w2).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected2, result2)
+
 
 class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 2f534b811b34..fbc2b32d1c69 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -20,19 +20,8 @@ from typing import cast
 from decimal import Decimal
 
 from pyspark.errors import AnalysisException, PythonException
-from pyspark.sql.functions import (
-    array,
-    explode,
-    col,
-    lit,
-    mean,
-    min,
-    max,
-    rank,
-    udf,
-    pandas_udf,
-    PandasUDFType,
-)
+from pyspark.sql import functions as sf
+from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
 from pyspark.sql.window import Window
 from pyspark.sql.types import (
     DecimalType,
@@ -64,10 +53,10 @@ class WindowPandasUDFTestsMixin:
         return (
             self.spark.range(10)
             .toDF("id")
-            .withColumn("vs", array([lit(i * 1.0) + col("id") for i in range(20, 30)]))
-            .withColumn("v", explode(col("vs")))
+            .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
+            .withColumn("v", sf.explode(sf.col("vs")))
             .drop("vs")
-            .withColumn("w", lit(1.0))
+            .withColumn("w", sf.lit(1.0))
         )
 
     @property
@@ -172,10 +161,10 @@ class WindowPandasUDFTestsMixin:
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("mean_v", mean_udf(df["v"]).over(w))
-        expected1 = df.withColumn("mean_v", mean(df["v"]).over(w))
+        expected1 = df.withColumn("mean_v", sf.mean(df["v"]).over(w))
 
         result2 = df.select(mean_udf(df["v"]).over(w))
-        expected2 = df.select(mean(df["v"]).over(w))
+        expected2 = df.select(sf.mean(df["v"]).over(w))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -191,9 +180,9 @@ class WindowPandasUDFTestsMixin:
         )
 
         expected1 = (
-            df.withColumn("mean_v", mean(df["v"]).over(w))
-            .withColumn("max_v", max(df["v"]).over(w))
-            .withColumn("min_w", min(df["w"]).over(w))
+            df.withColumn("mean_v", sf.mean(df["v"]).over(w))
+            .withColumn("max_v", sf.max(df["v"]).over(w))
+            .withColumn("min_w", sf.min(df["w"]).over(w))
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -203,7 +192,7 @@ class WindowPandasUDFTestsMixin:
         w = self.unbounded_window
 
         result1 = df.withColumn("v", self.pandas_agg_mean_udf(df["v"]).over(w))
-        expected1 = df.withColumn("v", mean(df["v"]).over(w))
+        expected1 = df.withColumn("v", sf.mean(df["v"]).over(w))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
@@ -213,7 +202,7 @@ class WindowPandasUDFTestsMixin:
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("v", mean_udf(df["v"] * 2).over(w) + 1)
-        expected1 = df.withColumn("v", mean(df["v"] * 2).over(w) + 1)
+        expected1 = df.withColumn("v", sf.mean(df["v"] * 2).over(w) + 1)
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
@@ -226,10 +215,10 @@ class WindowPandasUDFTestsMixin:
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("v2", 
plus_one(mean_udf(plus_one(df["v"])).over(w)))
-        expected1 = df.withColumn("v2", 
plus_one(mean(plus_one(df["v"])).over(w)))
+        expected1 = df.withColumn("v2", 
plus_one(sf.mean(plus_one(df["v"])).over(w)))
 
         result2 = df.withColumn("v2", 
time_two(mean_udf(time_two(df["v"])).over(w)))
-        expected2 = df.withColumn("v2", 
time_two(mean(time_two(df["v"])).over(w)))
+        expected2 = df.withColumn("v2", 
time_two(sf.mean(time_two(df["v"])).over(w)))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -240,10 +229,10 @@ class WindowPandasUDFTestsMixin:
         mean_udf = self.pandas_agg_mean_udf
 
         result1 = df.withColumn("v2", mean_udf(df["v"]).over(w))
-        expected1 = df.withColumn("v2", mean(df["v"]).over(w))
+        expected1 = df.withColumn("v2", sf.mean(df["v"]).over(w))
 
         result2 = df.select(mean_udf(df["v"]).over(w))
-        expected2 = df.select(mean(df["v"]).over(w))
+        expected2 = df.select(sf.mean(df["v"]).over(w))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -256,26 +245,28 @@ class WindowPandasUDFTestsMixin:
         min_udf = self.pandas_agg_min_udf
 
         result1 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - 
min_udf(df["v"]).over(w))
-        expected1 = df.withColumn("v_diff", max(df["v"]).over(w) - 
min(df["v"]).over(w))
+        expected1 = df.withColumn("v_diff", sf.max(df["v"]).over(w) - 
sf.min(df["v"]).over(w))
 
         # Test mixing sql window function and window udf in the same expression
-        result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - 
min(df["v"]).over(w))
+        result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) - 
sf.min(df["v"]).over(w))
         expected2 = expected1
 
         # Test chaining sql aggregate function and udf
         result3 = (
             df.withColumn("max_v", max_udf(df["v"]).over(w))
-            .withColumn("min_v", min(df["v"]).over(w))
-            .withColumn("v_diff", col("max_v") - col("min_v"))
+            .withColumn("min_v", sf.min(df["v"]).over(w))
+            .withColumn("v_diff", sf.col("max_v") - sf.col("min_v"))
             .drop("max_v", "min_v")
         )
         expected3 = expected1
 
         # Test mixing sql window function and udf
         result4 = df.withColumn("max_v", max_udf(df["v"]).over(w)).withColumn(
-            "rank", rank().over(ow)
+            "rank", sf.rank().over(ow)
+        )
+        expected4 = df.withColumn("max_v", sf.max(df["v"]).over(w)).withColumn(
+            "rank", sf.rank().over(ow)
         )
-        expected4 = df.withColumn("max_v", 
max(df["v"]).over(w)).withColumn("rank", rank().over(ow))
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
         assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -303,8 +294,6 @@ class WindowPandasUDFTestsMixin:
             df.withColumn("v2", foo_udf(df["v"]).over(w)).schema
 
     def test_bounded_simple(self):
-        from pyspark.sql.functions import mean, max, min, count
-
         df = self.data
         w1 = self.sliding_row_window
         w2 = self.shrinking_range_window
@@ -323,17 +312,15 @@ class WindowPandasUDFTestsMixin:
         )
 
         expected1 = (
-            df.withColumn("mean_v", mean(plus_one(df["v"])).over(w1))
-            .withColumn("count_v", count(df["v"]).over(w2))
-            .withColumn("max_v", max(df["v"]).over(w2))
-            .withColumn("min_v", min(df["v"]).over(w1))
+            df.withColumn("mean_v", sf.mean(plus_one(df["v"])).over(w1))
+            .withColumn("count_v", sf.count(df["v"]).over(w2))
+            .withColumn("max_v", sf.max(df["v"]).over(w2))
+            .withColumn("min_v", sf.min(df["v"]).over(w1))
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_growing_window(self):
-        from pyspark.sql.functions import mean
-
         df = self.data
         w1 = self.growing_row_window
         w2 = self.growing_range_window
@@ -344,15 +331,13 @@ class WindowPandasUDFTestsMixin:
             "m2", mean_udf(df["v"]).over(w2)
         )
 
-        expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
-            "m2", mean(df["v"]).over(w2)
+        expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+            "m2", sf.mean(df["v"]).over(w2)
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_sliding_window(self):
-        from pyspark.sql.functions import mean
-
         df = self.data
         w1 = self.sliding_row_window
         w2 = self.sliding_range_window
@@ -363,15 +348,13 @@ class WindowPandasUDFTestsMixin:
             "m2", mean_udf(df["v"]).over(w2)
         )
 
-        expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
-            "m2", mean(df["v"]).over(w2)
+        expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+            "m2", sf.mean(df["v"]).over(w2)
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_shrinking_window(self):
-        from pyspark.sql.functions import mean
-
         df = self.data
         w1 = self.shrinking_row_window
         w2 = self.shrinking_range_window
@@ -382,15 +365,13 @@ class WindowPandasUDFTestsMixin:
             "m2", mean_udf(df["v"]).over(w2)
         )
 
-        expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
-            "m2", mean(df["v"]).over(w2)
+        expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+            "m2", sf.mean(df["v"]).over(w2)
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
     def test_bounded_mixed(self):
-        from pyspark.sql.functions import mean, max
-
         df = self.data
         w1 = self.sliding_row_window
         w2 = self.unbounded_window
@@ -405,9 +386,9 @@ class WindowPandasUDFTestsMixin:
         )
 
         expected1 = (
-            df.withColumn("mean_v", mean(df["v"]).over(w1))
-            .withColumn("max_v", max(df["v"]).over(w2))
-            .withColumn("mean_unbounded_v", mean(df["v"]).over(w1))
+            df.withColumn("mean_v", sf.mean(df["v"]).over(w1))
+            .withColumn("max_v", sf.max(df["v"]).over(w2))
+            .withColumn("mean_unbounded_v", sf.mean(df["v"]).over(w1))
         )
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -425,7 +406,7 @@ class WindowPandasUDFTestsMixin:
                 ]
             ):
                 with self.subTest(bound=bound, query_no=i):
-                    assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+                    assertDataFrameEqual(windowed, df.withColumn("wm", sf.mean(df.v).over(w)))
 
         with self.tempView("v"):
             df.createOrReplaceTempView("v")
@@ -521,7 +502,7 @@ class WindowPandasUDFTestsMixin:
                 ]
             ):
                 with self.subTest(bound=bound, query_no=i):
-                    assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+                    assertDataFrameEqual(windowed, df.withColumn("wm", sf.mean(df.v).over(w)))
 
         with self.tempView("v"):
             df.createOrReplaceTempView("v")
@@ -608,6 +589,50 @@ class WindowPandasUDFTestsMixin:
                 result = df.select(mean_udf(df["v"]).over(w)).first()[0]
                 assert result == 123
 
+    def test_arrow_batch_slicing(self):
+        df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+        w1 = Window.partitionBy("key").orderBy("v")
+        w2 = (
+            Window.partitionBy("key")
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+            .orderBy("v")
+        )
+
+        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+        def pandas_sum(v):
+            return v.sum()
+
+        @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+        def pandas_sum_unbounded(v):
+            assert len(v) == 1000 / 2, len(v)
+            return v.sum()
+
+        expected1 = df.select("*", sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+        expected2 = df.select("*", sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+        for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+            with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+                with self.sql_conf(
+                    {
+                        "spark.sql.execution.arrow.maxRecordsPerBatch": maxRecords,
+                        "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+                    }
+                ):
+                    result1 = (
+                        df.select("*", pandas_sum("v").over(w1).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected1, result1)
+
+                    result2 = (
+                        df.select("*", pandas_sum_unbounded("v").over(w2).alias("res"))
+                        .sort("key", "v")
+                        .collect()
+                    )
+                    self.assertEqual(expected2, result2)
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c3ba8bc7063c..d94ba8f397bd 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -52,7 +52,6 @@ from pyspark.serializers import (
 from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
-    AggArrowUDFSerializer,
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
     GroupPandasUDFSerializer,
@@ -67,6 +66,7 @@ from pyspark.sql.pandas.serializers import (
     TransformWithStateInPySparkRowSerializer,
     TransformWithStateInPySparkRowInitStateSerializer,
     ArrowStreamArrowUDFSerializer,
+    ArrowStreamAggArrowUDFSerializer,
     ArrowBatchUDFSerializer,
     ArrowStreamUDTFSerializer,
     ArrowStreamArrowUDTFSerializer,
@@ -2612,11 +2612,15 @@ def read_udfs(pickleSer, infile, eval_type):
             or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
         ):
             ser = GroupArrowUDFSerializer(_assign_cols_by_name)
-        elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
-            ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
+        elif eval_type in (
+            PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+            PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+        ):
+            ser = ArrowStreamAggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
         elif eval_type in (
             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+            PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         ):
             ser = GroupPandasUDFSerializer(
                 timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
@@ -2703,7 +2707,6 @@ def read_udfs(pickleSer, infile, eval_type):
         elif eval_type in (
             PythonEvalType.SQL_SCALAR_ARROW_UDF,
             PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
-            PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
         ):
             # Arrow cast and safe check are always enabled
             ser = ArrowStreamArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 1643a8d3bdb1..82c03b1d0229 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -368,7 +368,7 @@ class ArrowWindowPythonEvaluatorFactory(
         }
       }
 
-      val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
+      val runner = new ArrowPythonWithNamedArgumentRunner(
         pyFuncs,
         evalType,
         argMetas,
@@ -378,7 +378,9 @@ class ArrowWindowPythonEvaluatorFactory(
         pythonRunnerConf,
         pythonMetrics,
         jobArtifactUUID,
-        profiler).compute(pythonInput, context.partitionId(), context)
+        profiler) with GroupedPythonArrowInput
+
+      val windowFunctionResult = runner.compute(pythonInput, context.partitionId(), context)
 
       val joined = new JoinedRow
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
