This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 148d552ecdf6 [SPARK-54212][PYTHON][DOCS] Fix the doctest of PandasGroupedOpsMixin.applyInArrow
148d552ecdf6 is described below
commit 148d552ecdf6103b7958aec620af565d0afe54ed
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 6 15:35:40 2025 +0800
[SPARK-54212][PYTHON][DOCS] Fix the doctest of PandasGroupedOpsMixin.applyInArrow
### What changes were proposed in this pull request?
Fix the doctest of PandasGroupedOpsMixin.applyInArrow so that it actually runs: remove the `# doctest: +SKIP` directives, switch to the `pa` alias, sort the output for a deterministic row order, and use ellipses for floating-point tails.
### Why are the changes needed?
Improve test coverage by enabling this doctest, which was previously skipped entirely.
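The rewritten examples make the output deterministic instead of skipping execution: grouped output arrives in no guaranteed order, so a sort() is appended before show(), and exact floating-point tails are replaced with doctest ellipses. A minimal illustration of the pattern (a sketch mirroring the diff below, not part of the patch):

    >>> out = df.groupby("id").applyInArrow(normalize, schema="id long, v double")
    >>> # sort() pins the otherwise-unspecified row order; in the expected
    >>> # output, "..." absorbs platform-dependent floating-point digits
    >>> out.sort("id", "v").show()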
### Does this PR introduce _any_ user-facing change?
Yes, doc-only changes.
### How was this patch tested?
CI
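(Locally, the enabled doctests can also be exercised from a Spark checkout via python/run-tests --testnames 'pyspark.sql.pandas.group_ops', assuming pyarrow is installed.)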
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #52910 from zhengruifeng/doc_test_grouped_applyinarrow.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/pandas/group_ops.py | 65 +++++++++++++++++++---------------
1 file changed, 37 insertions(+), 28 deletions(-)
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index ddad0450ec89..842bbe8e41c7 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -22,7 +22,6 @@ from pyspark.errors import PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.streaming.stateful_processor import StatefulProcessor
from pyspark.sql.types import StructType
@@ -810,43 +809,46 @@ class PandasGroupedOpsMixin:
Examples
--------
- >>> from pyspark.sql.functions import ceil
- >>> import pyarrow # doctest: +SKIP
- >>> import pyarrow.compute as pc # doctest: +SKIP
+ >>> from pyspark.sql import functions as sf
+ >>> import pyarrow as pa
+ >>> import pyarrow.compute as pc
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v")) # doctest: +SKIP
+ ... ("id", "v"))
>>> def normalize(table):
... v = table.column("v")
... norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
... return table.set_column(1, "v", norm)
>>> df.groupby("id").applyInArrow(
- ... normalize, schema="id long, v double").show() # doctest: +SKIP
+ ... normalize, schema="id long, v double"
+ ... ).sort("id", "v").show()
+---+-------------------+
| id| v|
+---+-------------------+
- | 1|-0.7071067811865475|
- | 1| 0.7071067811865475|
- | 2|-0.8320502943378437|
- | 2|-0.2773500981126146|
- | 2| 1.1094003924504583|
+ | 1|-0.7071067811865...|
+ | 1| 0.7071067811865...|
+ | 2|-0.8320502943378...|
+ | 2|-0.2773500981126...|
+ | 2| 1.1094003924504...|
+---+-------------------+
The function can also take and return an iterator of `pyarrow.RecordBatch` using type hints.
+ >>> from typing import Iterator, Tuple
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v")) # doctest: +SKIP
+ ... ("id", "v"))
>>> def sum_func(
- ... batches: Iterator[pyarrow.RecordBatch]
- ... ) -> Iterator[pyarrow.RecordBatch]: # doctest: +SKIP
+ ... batches: Iterator[pa.RecordBatch]
+ ... ) -> Iterator[pa.RecordBatch]:
... total = 0
... for batch in batches:
... total += pc.sum(batch.column("v")).as_py()
- ... yield pyarrow.RecordBatch.from_pydict({"v": [total]})
+ ... yield pa.RecordBatch.from_pydict({"v": [total]})
>>> df.groupby("id").applyInArrow(
- ... sum_func, schema="v double").show() # doctest: +SKIP
+ ... sum_func, schema="v double"
+ ... ).sort("v").show()
+----+
| v|
+----+
@@ -863,14 +865,15 @@ class PandasGroupedOpsMixin:
>>> df = spark.createDataFrame(
... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
- ... ("id", "v")) # doctest: +SKIP
+ ... ("id", "v"))
>>> def mean_func(key, table):
... # key is a tuple of one pyarrow.Int64Scalar, which is the value
... # of 'id' for the current group
... mean = pc.mean(table.column("v"))
- ... return pyarrow.Table.from_pydict({"id": [key[0].as_py()], "v": [mean.as_py()]})
+ ... return pa.Table.from_pydict({"id": [key[0].as_py()], "v": [mean.as_py()]})
>>> df.groupby('id').applyInArrow(
- ... mean_func, schema="id long, v double") # doctest: +SKIP
+ ... mean_func, schema="id long, v double"
+ ... ).sort("id").show()
+---+---+
| id| v|
+---+---+
@@ -882,31 +885,33 @@ class PandasGroupedOpsMixin:
... # key is a tuple of two pyarrow.Int64Scalars, which are the values
... # of 'id' and 'ceil(df.v / 2)' for the current group
... sum = pc.sum(table.column("v"))
- ... return pyarrow.Table.from_pydict({
+ ... return pa.Table.from_pydict({
... "id": [key[0].as_py()],
... "ceil(v / 2)": [key[1].as_py()],
... "v": [sum.as_py()]
... })
- >>> df.groupby(df.id, ceil(df.v / 2)).applyInArrow(
- ... sum_func, schema="id long, `ceil(v / 2)` long, v double").show() # doctest: +SKIP
+ >>> df.groupby(df.id, sf.ceil(df.v / 2)).applyInArrow(
+ ... sum_func, schema="id long, `ceil(v / 2)` long, v double"
+ ... ).sort("id", "v").show()
+---+-----------+----+
| id|ceil(v / 2)| v|
+---+-----------+----+
- | 2| 5|10.0|
| 1| 1| 3.0|
- | 2| 3| 5.0|
| 2| 2| 3.0|
+ | 2| 3| 5.0|
+ | 2| 5|10.0|
+---+-----------+----+
>>> def sum_func(
- ... key: Tuple[pyarrow.Scalar, ...], batches: Iterator[pyarrow.RecordBatch]
- ... ) -> Iterator[pyarrow.RecordBatch]: # doctest: +SKIP
+ ... key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
+ ... ) -> Iterator[pa.RecordBatch]:
... total = 0
... for batch in batches:
... total += pc.sum(batch.column("v")).as_py()
- ... yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
+ ... yield pa.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": [total]})
>>> df.groupby("id").applyInArrow(
- ... sum_func, schema="id long, v double").show() # doctest: +SKIP
+ ... sum_func, schema="id long, v double"
+ ... ).sort("id").show()
+---+----+
| id| v|
+---+----+
@@ -929,6 +934,7 @@ class PandasGroupedOpsMixin:
"""
from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf
+ from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
assert isinstance(self, GroupedData)
@@ -1200,6 +1206,9 @@ def _test() -> None:
del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.apply.__doc__
del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInPandas.__doc__
+ if not have_pyarrow:
+ del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInArrow.__doc__
+
spark = SparkSession.builder.master("local[4]").appName("sql.pandas.group tests").getOrCreate()
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
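For reference, the first fixed example corresponds to roughly this standalone script (a sketch assuming a PySpark build that ships GroupedData.applyInArrow and a pyarrow install; it mirrors the doctest above rather than introducing anything new):

    from pyspark.sql import SparkSession
    import pyarrow.compute as pc

    spark = SparkSession.builder.master("local[4]").getOrCreate()
    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

    def normalize(table):
        # each group's rows arrive as a single pyarrow.Table
        v = table.column("v")
        norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1))
        return table.set_column(1, "v", norm)

    # sort() keeps the displayed row order deterministic, as in the doctest
    df.groupby("id").applyInArrow(
        normalize, schema="id long, v double"
    ).sort("id", "v").show()

    spark.stop()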
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]