This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new 67f2ac5e6ecd [SPARK-54145][PYTHON][CONNECT] Fix column check of nested type in numeric aggregation
67f2ac5e6ecd is described below

commit 67f2ac5e6ecdf8083b1ef67460869d8f0bff2790
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Nov 3 13:27:40 2025 -0800

    [SPARK-54145][PYTHON][CONNECT] Fix column check of nested type in numeric aggregation
    
    ### What changes were proposed in this pull request?
    Fix the column check in numeric aggregation so that nested column references such as `b.c` are resolved against the schema instead of being compared to top-level field names only.
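    
    The new check resolves each requested name against the schema recursively (see `verify_numeric_col_name` in the diff below). A minimal sketch of the idea, simplified to ignore backtick-quoted names and array element types:
    
    ```py
    from pyspark.sql.types import DataType, NumericType, StructType
    
    def _is_numeric(name: str, schema: StructType) -> bool:
        # Minimal sketch, not the actual implementation: walk the schema
        # one dotted part at a time and check that the leaf type is numeric.
        dt: DataType = schema
        for part in name.split("."):
            if not isinstance(dt, StructType) or part not in dt.fieldNames():
                return False
            dt = dt[part].dataType
        return isinstance(dt, NumericType)
    ```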
    
    ### Why are the changes needed?
    The following query:
    
    ```py
            df = spark.createDataFrame(
                [
                    Row(a="a", b=Row(c=1)),
                    Row(a="a", b=Row(c=2)),
                    Row(a="a", b=Row(c=3)),
                    Row(a="b", b=Row(c=4)),
                    Row(a="b", b=Row(c=5)),
                ]
            )
    
            df.groupBy("a").max("b.c").show()
    ```
    
    In Spark Classic, it works:
    ```
    +---+-------------+
    |  a|max(b.c AS c)|
    +---+-------------+
    |  a|            3|
    |  b|            5|
    +---+-------------+
    ```
    
    In Spark Connect, it fails:
    ```
    ---------------------------------------------------------------------------
    PySparkTypeError                          Traceback (most recent call last)
    Cell In[2], line 11
          1 df = spark.createDataFrame(
          2     [
          3         Row(a="a", b=Row(c=1)),
       (...)      8     ]
          9 )
    ---> 11 df.groupBy("a").max("b.c").show()
    
    File ~/spark/python/pyspark/sql/connect/group.py:203, in GroupedData.max(self, *cols)
        202 def max(self: "GroupedData", *cols: str) -> "DataFrame":
    --> 203     return self._numeric_agg("max", list(cols))
    
    File ~/spark/python/pyspark/sql/connect/group.py:175, in GroupedData._numeric_agg(self, function, cols)
        173     invalid_cols = [c for c in cols if c not in numerical_cols]
        174     if len(invalid_cols) > 0:
    --> 175         raise PySparkTypeError(
        176             errorClass="NOT_NUMERIC_COLUMNS",
        177             messageParameters={"invalid_columns": str(invalid_cols)},
        178         )
        179     agg_cols = cols
        180 else:
        181     # if no column is provided, then all numerical columns are selected
    
    PySparkTypeError: [NOT_NUMERIC_COLUMNS] Numeric aggregation function can only be applied on numeric columns, got ['b.c'].
    
    ```
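    
    The failure comes from the old check, which compared the requested names only against top-level field names, so a nested reference like `b.c` could never match. Illustratively (a simplified paraphrase of the removed logic, using the `df` above):
    
    ```py
    from pyspark.sql.types import NumericType
    
    # Old check (simplified): collect only top-level numeric field names.
    numerical_cols = [f.name for f in df.schema.fields if isinstance(f.dataType, NumericType)]
    # numerical_cols == [] here: 'a' is a string and 'b' is a struct,
    # so 'b.c' is reported as non-numeric.
    assert "b.c" not in numerical_cols
    ```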
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the above query works after this fix.
    
    ### How was this patch tested?
    Added a test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #52844 from zhengruifeng/numeric_agg_nest_col.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit 94e00ca8f5fcdc1d3c430e1b54de8012101f1d8d)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/connect/group.py    | 11 +++++-----
 python/pyspark/sql/connect/types.py    | 38 +++++++++++++++++++++++++++++-----
 python/pyspark/sql/tests/test_group.py | 16 ++++++++++++++
 3 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 52d280c2c264..d540e721f149 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -158,6 +158,7 @@ class GroupedData:
 
     def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
         from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.connect.types import verify_numeric_col_name
 
         assert isinstance(function, str) and function in ["min", "max", "avg", "sum"]
 
@@ -165,12 +166,8 @@ class GroupedData:
 
         schema = self._df.schema
 
-        numerical_cols: List[str] = [
-            field.name for field in schema.fields if isinstance(field.dataType, NumericType)
-        ]
-
         if len(cols) > 0:
-            invalid_cols = [c for c in cols if c not in numerical_cols]
+            invalid_cols = [c for c in cols if not verify_numeric_col_name(c, schema)]
             if len(invalid_cols) > 0:
                 raise PySparkTypeError(
                     errorClass="NOT_NUMERIC_COLUMNS",
@@ -179,7 +176,9 @@ class GroupedData:
             agg_cols = cols
         else:
             # if no column is provided, then all numerical columns are selected
-            agg_cols = numerical_cols
+            agg_cols = [
+                field.name for field in schema.fields if isinstance(field.dataType, NumericType)
+            ]
 
         return DataFrame(
             plan.Aggregate(
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 8f9e7c0561cc..7e8f76861079 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -48,6 +48,7 @@ from pyspark.sql.types import (
     BinaryType,
     BooleanType,
     NullType,
+    NumericType,
     VariantType,
     UserDefinedType,
 )
@@ -367,15 +368,42 @@ def verify_col_name(name: str, schema: StructType) -> bool:
     if parts is None or len(parts) == 0:
         return False
 
-    def _quick_verify(parts: List[str], schema: DataType) -> bool:
+    def _quick_verify(parts: List[str], dt: DataType) -> bool:
         if len(parts) == 0:
             return True
 
         _schema: Optional[StructType] = None
-        if isinstance(schema, StructType):
-            _schema = schema
-        elif isinstance(schema, ArrayType) and isinstance(schema.elementType, StructType):
-            _schema = schema.elementType
+        if isinstance(dt, StructType):
+            _schema = dt
+        elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
+            _schema = dt.elementType
+        else:
+            return False
+
+        part = parts[0]
+        for field in _schema:
+            if field.name == part:
+                return _quick_verify(parts[1:], field.dataType)
+
+        return False
+
+    return _quick_verify(parts, schema)
+
+
+def verify_numeric_col_name(name: str, schema: StructType) -> bool:
+    parts = parse_attr_name(name)
+    if parts is None or len(parts) == 0:
+        return False
+
+    def _quick_verify(parts: List[str], dt: DataType) -> bool:
+        if len(parts) == 0 and isinstance(dt, NumericType):
+            return True
+
+        _schema: Optional[StructType] = None
+        if isinstance(dt, StructType):
+            _schema = dt
+        elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
+            _schema = dt.elementType
         else:
             return False
 
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index bbc089b00c13..ac868c34a913 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -126,6 +126,22 @@ class GroupTestsMixin:
             with self.assertRaises(IndexError):
                 df.groupBy(10).agg(sf.sum("b"))
 
+    def test_numeric_agg_with_nest_type(self):
+        df = self.spark.createDataFrame(
+            [
+                Row(a="a", b=Row(c=1)),
+                Row(a="a", b=Row(c=2)),
+                Row(a="a", b=Row(c=3)),
+                Row(a="b", b=Row(c=4)),
+                Row(a="b", b=Row(c=5)),
+            ]
+        )
+
+        res = df.groupBy("a").max("b.c").sort("a").collect()
+        # [Row(a='a', max(b.c AS c)=3), Row(a='b', max(b.c AS c)=5)]
+
+        self.assertEqual([["a", 3], ["b", 5]], [list(r) for r in res])
+
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)  # type: ignore
     def test_order_by_ordinal(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
