This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9ae819840bbf [SPARK-53611][PYTHON] Limit Arrow batch sizes in window agg UDFs
9ae819840bbf is described below
commit 9ae819840bbf2f6c8fc8e7d978c89f3d11b57b05
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Oct 15 09:59:20 2025 +0800
[SPARK-53611][PYTHON] Limit Arrow batch sizes in window agg UDFs
### What changes were proposed in this pull request?
Limit Arrow batch sizes in window agg UDFs
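Conceptually (a hypothetical Python sketch for illustration only; the actual slicing is implemented on the JVM side via GroupedPythonArrowInput, see the Scala change below), an over-sized Arrow batch is cut into sub-batches bounded by a record cap and a byte cap:

    import pyarrow as pa

    def slice_record_batch(batch: pa.RecordBatch, max_records: int, max_bytes: int):
        # Yield sub-batches of at most max_records rows; halve a slice while it
        # exceeds the byte cap, always emitting at least one row per slice.
        start = 0
        while start < batch.num_rows:
            step = min(max_records, batch.num_rows - start)
            while step > 1 and batch.slice(start, step).nbytes > max_bytes:
                step //= 2
            yield batch.slice(start, step)
            start += step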
### Why are the changes needed?
To avoid OOM on the JVM side, by slicing the JVM->Python Arrow batches according to spark.sql.execution.arrow.maxRecordsPerBatch and spark.sql.execution.arrow.maxBytesPerBatch.
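For example (a minimal usage sketch, assuming a running SparkSession bound to `spark`; the two config names are taken from the tests added in this patch):

    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")
    spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "64")
    # With both caps set, the JVM streams a window partition to the Python
    # worker as a sequence of small Arrow batches instead of one large batch,
    # while the grouped agg UDF still sees the full window.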
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52608 from zhengruifeng/limit_win_agg.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/serializers.py | 5 +-
.../sql/tests/arrow/test_arrow_udf_window.py | 46 +++++++
.../sql/tests/pandas/test_pandas_udf_window.py | 141 ++++++++++++---------
python/pyspark/worker.py | 11 +-
.../python/ArrowWindowPythonEvaluatorFactory.scala | 6 +-
5 files changed, 143 insertions(+), 66 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 323aea3a59ce..89b30668424a 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -1143,7 +1143,8 @@ class GroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
return "GroupArrowUDFSerializer"
-class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
+# Serializer for SQL_GROUPED_AGG_ARROW_UDF and SQL_WINDOW_AGG_ARROW_UDF
+class ArrowStreamAggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
def __init__(
self,
timezone,
@@ -1183,7 +1184,7 @@ class AggArrowUDFSerializer(ArrowStreamArrowUDFSerializer):
)
def __repr__(self):
- return "AggArrowUDFSerializer"
+ return "ArrowStreamAggArrowUDFSerializer"
class GroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index 1d301597c21a..b3ed4c020ca3 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -758,6 +758,52 @@ class WindowArrowUDFTestsMixin:
)
self.assertEqual(expected.collect(), result.collect())
+ def test_arrow_batch_slicing(self):
+ import pyarrow as pa
+
+ df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+ w1 = Window.partitionBy("key").orderBy("v")
+ w2 = (
+ Window.partitionBy("key")
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .orderBy("v")
+ )
+
+ @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+ def arrow_sum(v):
+ return pa.compute.sum(v)
+
+ @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+ def arrow_sum_unbounded(v):
+ assert len(v) == 1000 / 2, len(v)
+ return pa.compute.sum(v)
+
+ expected1 = df.select("*",
sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+ expected2 = df.select("*",
sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+ for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+ with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch":
maxRecords,
+ "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+ }
+ ):
+ result1 = (
+ df.select("*", arrow_sum("v").over(w1).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected1, result1)
+
+ result2 = (
+ df.select("*",
arrow_sum_unbounded("v").over(w2).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected2, result2)
+
class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 2f534b811b34..fbc2b32d1c69 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -20,19 +20,8 @@ from typing import cast
from decimal import Decimal
from pyspark.errors import AnalysisException, PythonException
-from pyspark.sql.functions import (
- array,
- explode,
- col,
- lit,
- mean,
- min,
- max,
- rank,
- udf,
- pandas_udf,
- PandasUDFType,
-)
+from pyspark.sql import functions as sf
+from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.window import Window
from pyspark.sql.types import (
DecimalType,
@@ -64,10 +53,10 @@ class WindowPandasUDFTestsMixin:
return (
self.spark.range(10)
.toDF("id")
- .withColumn("vs", array([lit(i * 1.0) + col("id") for i in
range(20, 30)]))
- .withColumn("v", explode(col("vs")))
+ .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i
in range(20, 30)]))
+ .withColumn("v", sf.explode(sf.col("vs")))
.drop("vs")
- .withColumn("w", lit(1.0))
+ .withColumn("w", sf.lit(1.0))
)
@property
@@ -172,10 +161,10 @@ class WindowPandasUDFTestsMixin:
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("mean_v", mean_udf(df["v"]).over(w))
- expected1 = df.withColumn("mean_v", mean(df["v"]).over(w))
+ expected1 = df.withColumn("mean_v", sf.mean(df["v"]).over(w))
result2 = df.select(mean_udf(df["v"]).over(w))
- expected2 = df.select(mean(df["v"]).over(w))
+ expected2 = df.select(sf.mean(df["v"]).over(w))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -191,9 +180,9 @@ class WindowPandasUDFTestsMixin:
)
expected1 = (
- df.withColumn("mean_v", mean(df["v"]).over(w))
- .withColumn("max_v", max(df["v"]).over(w))
- .withColumn("min_w", min(df["w"]).over(w))
+ df.withColumn("mean_v", sf.mean(df["v"]).over(w))
+ .withColumn("max_v", sf.max(df["v"]).over(w))
+ .withColumn("min_w", sf.min(df["w"]).over(w))
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -203,7 +192,7 @@ class WindowPandasUDFTestsMixin:
w = self.unbounded_window
result1 = df.withColumn("v", self.pandas_agg_mean_udf(df["v"]).over(w))
- expected1 = df.withColumn("v", mean(df["v"]).over(w))
+ expected1 = df.withColumn("v", sf.mean(df["v"]).over(w))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -213,7 +202,7 @@ class WindowPandasUDFTestsMixin:
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("v", mean_udf(df["v"] * 2).over(w) + 1)
- expected1 = df.withColumn("v", mean(df["v"] * 2).over(w) + 1)
+ expected1 = df.withColumn("v", sf.mean(df["v"] * 2).over(w) + 1)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -226,10 +215,10 @@ class WindowPandasUDFTestsMixin:
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("v2",
plus_one(mean_udf(plus_one(df["v"])).over(w)))
- expected1 = df.withColumn("v2",
plus_one(mean(plus_one(df["v"])).over(w)))
+ expected1 = df.withColumn("v2",
plus_one(sf.mean(plus_one(df["v"])).over(w)))
result2 = df.withColumn("v2",
time_two(mean_udf(time_two(df["v"])).over(w)))
- expected2 = df.withColumn("v2",
time_two(mean(time_two(df["v"])).over(w)))
+ expected2 = df.withColumn("v2",
time_two(sf.mean(time_two(df["v"])).over(w)))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -240,10 +229,10 @@ class WindowPandasUDFTestsMixin:
mean_udf = self.pandas_agg_mean_udf
result1 = df.withColumn("v2", mean_udf(df["v"]).over(w))
- expected1 = df.withColumn("v2", mean(df["v"]).over(w))
+ expected1 = df.withColumn("v2", sf.mean(df["v"]).over(w))
result2 = df.select(mean_udf(df["v"]).over(w))
- expected2 = df.select(mean(df["v"]).over(w))
+ expected2 = df.select(sf.mean(df["v"]).over(w))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -256,26 +245,28 @@ class WindowPandasUDFTestsMixin:
min_udf = self.pandas_agg_min_udf
result1 = df.withColumn("v_diff", max_udf(df["v"]).over(w) -
min_udf(df["v"]).over(w))
- expected1 = df.withColumn("v_diff", max(df["v"]).over(w) -
min(df["v"]).over(w))
+ expected1 = df.withColumn("v_diff", sf.max(df["v"]).over(w) -
sf.min(df["v"]).over(w))
# Test mixing sql window function and window udf in the same expression
- result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) -
min(df["v"]).over(w))
+ result2 = df.withColumn("v_diff", max_udf(df["v"]).over(w) -
sf.min(df["v"]).over(w))
expected2 = expected1
# Test chaining sql aggregate function and udf
result3 = (
df.withColumn("max_v", max_udf(df["v"]).over(w))
- .withColumn("min_v", min(df["v"]).over(w))
- .withColumn("v_diff", col("max_v") - col("min_v"))
+ .withColumn("min_v", sf.min(df["v"]).over(w))
+ .withColumn("v_diff", sf.col("max_v") - sf.col("min_v"))
.drop("max_v", "min_v")
)
expected3 = expected1
# Test mixing sql window function and udf
result4 = df.withColumn("max_v", max_udf(df["v"]).over(w)).withColumn(
- "rank", rank().over(ow)
+ "rank", sf.rank().over(ow)
+ )
+ expected4 = df.withColumn("max_v", sf.max(df["v"]).over(w)).withColumn(
+ "rank", sf.rank().over(ow)
)
- expected4 = df.withColumn("max_v",
max(df["v"]).over(w)).withColumn("rank", rank().over(ow))
assert_frame_equal(expected1.toPandas(), result1.toPandas())
assert_frame_equal(expected2.toPandas(), result2.toPandas())
@@ -303,8 +294,6 @@ class WindowPandasUDFTestsMixin:
df.withColumn("v2", foo_udf(df["v"]).over(w)).schema
def test_bounded_simple(self):
- from pyspark.sql.functions import mean, max, min, count
-
df = self.data
w1 = self.sliding_row_window
w2 = self.shrinking_range_window
@@ -323,17 +312,15 @@ class WindowPandasUDFTestsMixin:
)
expected1 = (
- df.withColumn("mean_v", mean(plus_one(df["v"])).over(w1))
- .withColumn("count_v", count(df["v"]).over(w2))
- .withColumn("max_v", max(df["v"]).over(w2))
- .withColumn("min_v", min(df["v"]).over(w1))
+ df.withColumn("mean_v", sf.mean(plus_one(df["v"])).over(w1))
+ .withColumn("count_v", sf.count(df["v"]).over(w2))
+ .withColumn("max_v", sf.max(df["v"]).over(w2))
+ .withColumn("min_v", sf.min(df["v"]).over(w1))
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_growing_window(self):
- from pyspark.sql.functions import mean
-
df = self.data
w1 = self.growing_row_window
w2 = self.growing_range_window
@@ -344,15 +331,13 @@ class WindowPandasUDFTestsMixin:
"m2", mean_udf(df["v"]).over(w2)
)
- expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
- "m2", mean(df["v"]).over(w2)
+ expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+ "m2", sf.mean(df["v"]).over(w2)
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_sliding_window(self):
- from pyspark.sql.functions import mean
-
df = self.data
w1 = self.sliding_row_window
w2 = self.sliding_range_window
@@ -363,15 +348,13 @@ class WindowPandasUDFTestsMixin:
"m2", mean_udf(df["v"]).over(w2)
)
- expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
- "m2", mean(df["v"]).over(w2)
+ expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+ "m2", sf.mean(df["v"]).over(w2)
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_shrinking_window(self):
- from pyspark.sql.functions import mean
-
df = self.data
w1 = self.shrinking_row_window
w2 = self.shrinking_range_window
@@ -382,15 +365,13 @@ class WindowPandasUDFTestsMixin:
"m2", mean_udf(df["v"]).over(w2)
)
- expected1 = df.withColumn("m1", mean(df["v"]).over(w1)).withColumn(
- "m2", mean(df["v"]).over(w2)
+ expected1 = df.withColumn("m1", sf.mean(df["v"]).over(w1)).withColumn(
+ "m2", sf.mean(df["v"]).over(w2)
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
def test_bounded_mixed(self):
- from pyspark.sql.functions import mean, max
-
df = self.data
w1 = self.sliding_row_window
w2 = self.unbounded_window
@@ -405,9 +386,9 @@ class WindowPandasUDFTestsMixin:
)
expected1 = (
- df.withColumn("mean_v", mean(df["v"]).over(w1))
- .withColumn("max_v", max(df["v"]).over(w2))
- .withColumn("mean_unbounded_v", mean(df["v"]).over(w1))
+ df.withColumn("mean_v", sf.mean(df["v"]).over(w1))
+ .withColumn("max_v", sf.max(df["v"]).over(w2))
+ .withColumn("mean_unbounded_v", sf.mean(df["v"]).over(w1))
)
assert_frame_equal(expected1.toPandas(), result1.toPandas())
@@ -425,7 +406,7 @@ class WindowPandasUDFTestsMixin:
]
):
with self.subTest(bound=bound, query_no=i):
- assertDataFrameEqual(windowed, df.withColumn("wm",
mean(df.v).over(w)))
+ assertDataFrameEqual(windowed, df.withColumn("wm",
sf.mean(df.v).over(w)))
with self.tempView("v"):
df.createOrReplaceTempView("v")
@@ -521,7 +502,7 @@ class WindowPandasUDFTestsMixin:
]
):
with self.subTest(bound=bound, query_no=i):
- assertDataFrameEqual(windowed, df.withColumn("wm",
mean(df.v).over(w)))
+ assertDataFrameEqual(windowed, df.withColumn("wm",
sf.mean(df.v).over(w)))
with self.tempView("v"):
df.createOrReplaceTempView("v")
@@ -608,6 +589,50 @@ class WindowPandasUDFTestsMixin:
result = df.select(mean_udf(df["v"]).over(w)).first()[0]
assert result == 123
+ def test_arrow_batch_slicing(self):
+ df = self.spark.range(1000).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
+
+ w1 = Window.partitionBy("key").orderBy("v")
+ w2 = (
+ Window.partitionBy("key")
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .orderBy("v")
+ )
+
+ @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+ def pandas_sum(v):
+ return v.sum()
+
+ @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+ def pandas_sum_unbounded(v):
+ assert len(v) == 1000 / 2, len(v)
+ return v.sum()
+
+ expected1 = df.select("*",
sf.sum("v").over(w1).alias("res")).sort("key", "v").collect()
+ expected2 = df.select("*",
sf.sum("v").over(w2).alias("res")).sort("key", "v").collect()
+
+ for maxRecords, maxBytes in [(10, 2**31 - 1), (0, 64), (10, 64)]:
+ with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
+ with self.sql_conf(
+ {
+ "spark.sql.execution.arrow.maxRecordsPerBatch":
maxRecords,
+ "spark.sql.execution.arrow.maxBytesPerBatch": maxBytes,
+ }
+ ):
+ result1 = (
+ df.select("*", pandas_sum("v").over(w1).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected1, result1)
+
+ result2 = (
+ df.select("*",
pandas_sum_unbounded("v").over(w2).alias("res"))
+ .sort("key", "v")
+ .collect()
+ )
+ self.assertEqual(expected2, result2)
+
class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c3ba8bc7063c..d94ba8f397bd 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -52,7 +52,6 @@ from pyspark.serializers import (
from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
- AggArrowUDFSerializer,
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
GroupPandasUDFSerializer,
@@ -67,6 +66,7 @@ from pyspark.sql.pandas.serializers import (
TransformWithStateInPySparkRowSerializer,
TransformWithStateInPySparkRowInitStateSerializer,
ArrowStreamArrowUDFSerializer,
+ ArrowStreamAggArrowUDFSerializer,
ArrowBatchUDFSerializer,
ArrowStreamUDTFSerializer,
ArrowStreamArrowUDTFSerializer,
@@ -2612,11 +2612,15 @@ def read_udfs(pickleSer, infile, eval_type):
or eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
):
ser = GroupArrowUDFSerializer(_assign_cols_by_name)
- elif eval_type == PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF:
- ser = AggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
+ elif eval_type in (
+ PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
+ ):
+ ser = ArrowStreamAggArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
elif eval_type in (
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
):
ser = GroupPandasUDFSerializer(
timezone, safecheck, _assign_cols_by_name, int_to_decimal_coercion_enabled
@@ -2703,7 +2707,6 @@ def read_udfs(pickleSer, infile, eval_type):
elif eval_type in (
PythonEvalType.SQL_SCALAR_ARROW_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
- PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF,
):
# Arrow cast and safe check are always enabled
ser = ArrowStreamArrowUDFSerializer(timezone, True, _assign_cols_by_name, True)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
index 1643a8d3bdb1..82c03b1d0229 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala
@@ -368,7 +368,7 @@ class ArrowWindowPythonEvaluatorFactory(
}
}
- val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
+ val runner = new ArrowPythonWithNamedArgumentRunner(
pyFuncs,
evalType,
argMetas,
@@ -378,7 +378,9 @@ class ArrowWindowPythonEvaluatorFactory(
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
- profiler).compute(pythonInput, context.partitionId(), context)
+ profiler) with GroupedPythonArrowInput
+
+ val windowFunctionResult = runner.compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]