This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 148d552ecdf6 [SPARK-54212][PYTHON][DOCS] Fix the doctest of PandasGroupedOpsMixin.applyInArrow
148d552ecdf6 is described below

commit 148d552ecdf6103b7958aec620af565d0afe54ed
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 6 15:35:40 2025 +0800

    [SPARK-54212][PYTHON][DOCS] Fix the doctest of PandasGroupedOpsMixin.applyInArrow
    
    ### What changes were proposed in this pull request?
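    Fix and enable the doctest of `PandasGroupedOpsMixin.applyInArrow`: remove the `# doctest: +SKIP` directives, use consistent `pa`/`sf` import aliases, and sort the results so the printed output is deterministic.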
    
    ### Why are the changes needed?
    To improve test coverage: this doctest was previously skipped and is now executed.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, documentation-only changes.
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52910 from zhengruifeng/doc_test_grouped_applyinarrow.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/pandas/group_ops.py | 65 +++++++++++++++++++---------------
 1 file changed, 37 insertions(+), 28 deletions(-)

diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index ddad0450ec89..842bbe8e41c7 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -22,7 +22,6 @@ from pyspark.errors import PySparkTypeError
 from pyspark.util import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.pandas.typehints import infer_group_arrow_eval_type_from_func
 from pyspark.sql.streaming.state import GroupStateTimeout
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor
 from pyspark.sql.types import StructType
@@ -810,43 +809,46 @@ class PandasGroupedOpsMixin:
 
         Examples
         --------
-        >>> from pyspark.sql.functions import ceil
-        >>> import pyarrow  # doctest: +SKIP
-        >>> import pyarrow.compute as pc  # doctest: +SKIP
+        >>> from pyspark.sql import functions as sf
+        >>> import pyarrow as pa
+        >>> import pyarrow.compute as pc
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def normalize(table):
         ...     v = table.column("v")
         ...     norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, 
ddof=1))
         ...     return table.set_column(1, "v", norm)
         >>> df.groupby("id").applyInArrow(
-        ...     normalize, schema="id long, v double").show()  # doctest: +SKIP
+        ...     normalize, schema="id long, v double"
+        ... ).sort("id", "v").show()
         +---+-------------------+
         | id|                  v|
         +---+-------------------+
-        |  1|-0.7071067811865475|
-        |  1| 0.7071067811865475|
-        |  2|-0.8320502943378437|
-        |  2|-0.2773500981126146|
-        |  2| 1.1094003924504583|
+        |  1|-0.7071067811865...|
+        |  1| 0.7071067811865...|
+        |  2|-0.8320502943378...|
+        |  2|-0.2773500981126...|
+        |  2| 1.1094003924504...|
         +---+-------------------+
 
         The function can also take and return an iterator of 
`pyarrow.RecordBatch` using type
         hints.
 
+        >>> from typing import Iterator, Tuple
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def sum_func(
-        ...     batches: Iterator[pyarrow.RecordBatch]
-        ... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+        ...     batches: Iterator[pa.RecordBatch]
+        ... ) -> Iterator[pa.RecordBatch]:
         ...     total = 0
         ...     for batch in batches:
         ...         total += pc.sum(batch.column("v")).as_py()
-        ...     yield pyarrow.RecordBatch.from_pydict({"v": [total]})
+        ...     yield pa.RecordBatch.from_pydict({"v": [total]})
         >>> df.groupby("id").applyInArrow(
-        ...     sum_func, schema="v double").show()  # doctest: +SKIP
+        ...     sum_func, schema="v double"
+        ... ).sort("v").show()
         +----+
         |   v|
         +----+
@@ -863,14 +865,15 @@ class PandasGroupedOpsMixin:
 
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def mean_func(key, table):
         ...     # key is a tuple of one pyarrow.Int64Scalar, which is the value
         ...     # of 'id' for the current group
         ...     mean = pc.mean(table.column("v"))
-        ...     return pyarrow.Table.from_pydict({"id": [key[0].as_py()], "v": 
[mean.as_py()]})
+        ...     return pa.Table.from_pydict({"id": [key[0].as_py()], "v": 
[mean.as_py()]})
         >>> df.groupby('id').applyInArrow(
-        ...     mean_func, schema="id long, v double")  # doctest: +SKIP
+        ...     mean_func, schema="id long, v double"
+        ... ).sort("id").show()
         +---+---+
         | id|  v|
         +---+---+
@@ -882,31 +885,33 @@ class PandasGroupedOpsMixin:
         ...     # key is a tuple of two pyarrow.Int64Scalars, which is the 
values
         ...     # of 'id' and 'ceil(df.v / 2)' for the current group
         ...     sum = pc.sum(table.column("v"))
-        ...     return pyarrow.Table.from_pydict({
+        ...     return pa.Table.from_pydict({
         ...         "id": [key[0].as_py()],
         ...         "ceil(v / 2)": [key[1].as_py()],
         ...         "v": [sum.as_py()]
         ...     })
-        >>> df.groupby(df.id, ceil(df.v / 2)).applyInArrow(
-        ...     sum_func, schema="id long, `ceil(v / 2)` long, v 
double").show()  # doctest: +SKIP
+        >>> df.groupby(df.id, sf.ceil(df.v / 2)).applyInArrow(
+        ...     sum_func, schema="id long, `ceil(v / 2)` long, v double"
+        ... ).sort("id", "v").show()
         +---+-----------+----+
         | id|ceil(v / 2)|   v|
         +---+-----------+----+
-        |  2|          5|10.0|
         |  1|          1| 3.0|
-        |  2|          3| 5.0|
         |  2|          2| 3.0|
+        |  2|          3| 5.0|
+        |  2|          5|10.0|
         +---+-----------+----+
 
         >>> def sum_func(
-        ...     key: Tuple[pyarrow.Scalar, ...], batches: 
Iterator[pyarrow.RecordBatch]
-        ... ) -> Iterator[pyarrow.RecordBatch]:  # doctest: +SKIP
+        ...     key: Tuple[pa.Scalar, ...], batches: Iterator[pa.RecordBatch]
+        ... ) -> Iterator[pa.RecordBatch]:
         ...     total = 0
         ...     for batch in batches:
         ...         total += pc.sum(batch.column("v")).as_py()
-        ...     yield pyarrow.RecordBatch.from_pydict({"id": [key[0].as_py()], 
"v": [total]})
+        ...     yield pa.RecordBatch.from_pydict({"id": [key[0].as_py()], "v": 
[total]})
         >>> df.groupby("id").applyInArrow(
-        ...     sum_func, schema="id long, v double").show()  # doctest: +SKIP
+        ...     sum_func, schema="id long, v double"
+        ... ).sort("id").show()
         +---+----+
         | id|   v|
         +---+----+
@@ -929,6 +934,7 @@ class PandasGroupedOpsMixin:
         """
         from pyspark.sql import GroupedData
         from pyspark.sql.functions import pandas_udf
+        from pyspark.sql.pandas.typehints import 
infer_group_arrow_eval_type_from_func
 
         assert isinstance(self, GroupedData)
 
@@ -1200,6 +1206,9 @@ def _test() -> None:
         del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.apply.__doc__
         del 
pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInPandas.__doc__
 
+    if not have_pyarrow:
+        del 
pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInArrow.__doc__
+
     spark = SparkSession.builder.master("local[4]").appName("sql.pandas.group 
tests").getOrCreate()
     globs["spark"] = spark
     (failure_count, test_count) = doctest.testmod(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to