This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 25c550ef37bc [SPARK-53695][PYTHON][TESTS] Add tests for 0-arg grouped
agg UDF
25c550ef37bc is described below
commit 25c550ef37bcf4658b2d05326bd8563de78da167
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Sep 25 08:10:33 2025 +0900
[SPARK-53695][PYTHON][TESTS] Add tests for 0-arg grouped agg UDF
### What changes were proposed in this pull request?
Add tests for 0-arg grouped aggregate vectorized UDFs (both Arrow and pandas).
### Why are the changes needed?
To guard the 0-arg cases, for example:
```
In [6]: @pandas_udf("double")
...: def mean_udf2() -> float:
...: return 1.0
...:
In [7]: spark.range(10).select(mean_udf2()).show()
+-----------+
|mean_udf2()|
+-----------+
| 1.0|
+-----------+
```
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #52437 from zhengruifeng/grouped_agg_0_arg.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/tests/arrow/test_arrow_udf_grouped_agg.py | 35 ++++++++++++++++++++++
.../tests/pandas/test_pandas_udf_grouped_agg.py | 35 +++++++++++++++++++++-
2 files changed, 69 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index d49f341788be..3fe6d28c66a6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -951,6 +951,41 @@ class GroupedAggArrowUDFTestsMixin:
def func_a(a: pa.Array) -> pa.Scalar:
return pa.compute.max(a)
+ def test_0_args(self):
+ import pyarrow as pa
+
+ df = self.spark.range(10).withColumn("k", sf.col("id") % 3)
+
+ @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+ def arrow_max(v) -> int:
+ return pa.compute.max(v).as_py()
+
+ @arrow_udf("long", ArrowUDFType.GROUPED_AGG)
+ def arrow_lit_1() -> int:
+ return 1
+
+ expected1 = df.select(sf.max("id").alias("res1"), sf.lit(1).alias("res1"))
+ result1 = df.select(arrow_max("id").alias("res1"), arrow_lit_1().alias("res1"))
+ self.assertEqual(expected1.collect(), result1.collect())
+
+ expected2 = (
+ df.groupby("k")
+ .agg(
+ sf.max("id").alias("res1"),
+ sf.lit(1).alias("res1"),
+ )
+ .sort("k")
+ )
+ result2 = (
+ df.groupby("k")
+ .agg(
+ arrow_max("id").alias("res1"),
+ arrow_lit_1().alias("res1"),
+ )
+ .sort("k")
+ )
+ self.assertEqual(expected2.collect(), result2.collect())
+
class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 1059af59f4a8..65f842fa70ad 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -19,7 +19,7 @@ import unittest
from typing import cast
from pyspark.util import PythonEvalType
-from pyspark.sql import Row
+from pyspark.sql import Row, functions as sf
from pyspark.sql.functions import (
array,
explode,
@@ -761,6 +761,39 @@ class GroupedAggPandasUDFTestsMixin:
row = df.groupby("id").agg(test(df.id)).first()
self.assertEqual(row[1], 123)
+ def test_0_args(self):
+ df = self.spark.range(10).withColumn("k", sf.col("id") % 3)
+
+ @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+ def pandas_max(v) -> int:
+ return v.max()
+
+ @pandas_udf("long", PandasUDFType.GROUPED_AGG)
+ def pandas_lit_1() -> int:
+ return 1
+
+ expected1 = df.select(sf.max("id").alias("res1"), sf.lit(1).alias("res1"))
+ result1 = df.select(pandas_max("id").alias("res1"), pandas_lit_1().alias("res1"))
+ self.assertEqual(expected1.collect(), result1.collect())
+
+ expected2 = (
+ df.groupby("k")
+ .agg(
+ sf.max("id").alias("res1"),
+ sf.lit(1).alias("res1"),
+ )
+ .sort("k")
+ )
+ result2 = (
+ df.groupby("k")
+ .agg(
+ pandas_max("id").alias("res1"),
+ pandas_lit_1().alias("res1"),
+ )
+ .sort("k")
+ )
+ self.assertEqual(expected2.collect(), result2.collect())
+
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]