This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 67f2ac5e6ecd [SPARK-54145][PYTHON][CONNECT] Fix column check of nested type in numeric aggregation
67f2ac5e6ecd is described below
commit 67f2ac5e6ecdf8083b1ef67460869d8f0bff2790
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Nov 3 13:27:40 2025 -0800
[SPARK-54145][PYTHON][CONNECT] Fix column check of nested type in numeric aggregation
### What changes were proposed in this pull request?
Fix the column check in numeric aggregations (min/max/avg/sum) so that nested column references such as b.c are resolved against the schema instead of being compared only against top-level field names.
### Why are the changes needed?
query:
```py
df = spark.createDataFrame(
[
Row(a="a", b=Row(c=1)),
Row(a="a", b=Row(c=2)),
Row(a="a", b=Row(c=3)),
Row(a="b", b=Row(c=4)),
Row(a="b", b=Row(c=5)),
]
)
df.groupBy("a").max("b.c").show()
```
in classic:
```
+---+-------------+
|  a|max(b.c AS c)|
+---+-------------+
|  a|            3|
|  b|            5|
+---+-------------+
```
in connect:
```
---------------------------------------------------------------------------
PySparkTypeError Traceback (most recent call last)
Cell In[2], line 11
1 df = spark.createDataFrame(
2 [
3 Row(a="a", b=Row(c=1)),
(...) 8 ]
9 )
---> 11 df.groupBy("a").max("b.c").show()
File ~/spark/python/pyspark/sql/connect/group.py:203, in GroupedData.max(self, *cols)
202 def max(self: "GroupedData", *cols: str) -> "DataFrame":
--> 203 return self._numeric_agg("max", list(cols))
File ~/spark/python/pyspark/sql/connect/group.py:175, in GroupedData._numeric_agg(self, function, cols)
173 invalid_cols = [c for c in cols if c not in numerical_cols]
174 if len(invalid_cols) > 0:
--> 175 raise PySparkTypeError(
176 errorClass="NOT_NUMERIC_COLUMNS",
177 messageParameters={"invalid_columns":
str(invalid_cols)},
178 )
179 agg_cols = cols
180 else:
181 # if no column is provided, then all numerical columns are selected
PySparkTypeError: [NOT_NUMERIC_COLUMNS] Numeric aggregation function can only be applied on numeric columns, got ['b.c'].
```
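The error happens because the old check compares each requested column name against the list of top-level numeric field names, so a nested reference such as "b.c" is never matched. The fix instead resolves the dotted name against the schema and only then checks that the resolved field is numeric. A minimal standalone sketch of that idea (is_numeric_col and walk are illustrative names, not part of the patch; the actual helper added here is verify_numeric_col_name in pyspark/sql/connect/types.py, which uses parse_attr_name and therefore also handles backquoted names):
```py
from pyspark.sql.types import ArrayType, DataType, NumericType, StructType

def is_numeric_col(name: str, schema: StructType) -> bool:
    # Simplified: split on "." only; the real helper parses backquoted names too.
    def walk(parts: list, dt: DataType) -> bool:
        if not parts:
            # fully resolved: the leaf field must be a numeric type
            return isinstance(dt, NumericType)
        if isinstance(dt, StructType):
            fields = dt
        elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
            fields = dt.elementType
        else:
            return False
        for field in fields:
            if field.name == parts[0]:
                return walk(parts[1:], field.dataType)
        return False

    return walk(name.split("."), schema)
```
With the schema from the query above, is_numeric_col("b.c", df.schema) is True while is_numeric_col("a", df.schema) is False, so only genuinely non-numeric columns trigger NOT_NUMERIC_COLUMNS.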
### Does this PR introduce _any_ user-facing change?
Yes, the above query works in Spark Connect after this fix and returns the same result as classic Spark.
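For reference, the same show() call in connect should now produce the classic output shown above (the new test asserts the collected rows are [['a', 3], ['b', 5]]):
```
+---+-------------+
|  a|max(b.c AS c)|
+---+-------------+
|  a|            3|
|  b|            5|
+---+-------------+
```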
### How was this patch tested?
added test
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #52844 from zhengruifeng/numeric_agg_nest_col.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 94e00ca8f5fcdc1d3c430e1b54de8012101f1d8d)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/connect/group.py | 11 +++++-----
python/pyspark/sql/connect/types.py | 38 +++++++++++++++++++++++++++++-----
python/pyspark/sql/tests/test_group.py | 16 ++++++++++++++
3 files changed, 54 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 52d280c2c264..d540e721f149 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -158,6 +158,7 @@ class GroupedData:
def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
from pyspark.sql.connect.dataframe import DataFrame
+ from pyspark.sql.connect.types import verify_numeric_col_name
assert isinstance(function, str) and function in ["min", "max", "avg", "sum"]
@@ -165,12 +166,8 @@ class GroupedData:
schema = self._df.schema
- numerical_cols: List[str] = [
- field.name for field in schema.fields if isinstance(field.dataType, NumericType)
- ]
-
if len(cols) > 0:
- invalid_cols = [c for c in cols if c not in numerical_cols]
+ invalid_cols = [c for c in cols if not verify_numeric_col_name(c, schema)]
if len(invalid_cols) > 0:
raise PySparkTypeError(
errorClass="NOT_NUMERIC_COLUMNS",
@@ -179,7 +176,9 @@ class GroupedData:
agg_cols = cols
else:
# if no column is provided, then all numerical columns are selected
- agg_cols = numerical_cols
+ agg_cols = [
+ field.name for field in schema.fields if isinstance(field.dataType, NumericType)
+ ]
return DataFrame(
plan.Aggregate(
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 8f9e7c0561cc..7e8f76861079 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -48,6 +48,7 @@ from pyspark.sql.types import (
BinaryType,
BooleanType,
NullType,
+ NumericType,
VariantType,
UserDefinedType,
)
@@ -367,15 +368,42 @@ def verify_col_name(name: str, schema: StructType) -> bool:
if parts is None or len(parts) == 0:
return False
- def _quick_verify(parts: List[str], schema: DataType) -> bool:
+ def _quick_verify(parts: List[str], dt: DataType) -> bool:
if len(parts) == 0:
return True
_schema: Optional[StructType] = None
- if isinstance(schema, StructType):
- _schema = schema
- elif isinstance(schema, ArrayType) and isinstance(schema.elementType, StructType):
- _schema = schema.elementType
+ if isinstance(dt, StructType):
+ _schema = dt
+ elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
+ _schema = dt.elementType
+ else:
+ return False
+
+ part = parts[0]
+ for field in _schema:
+ if field.name == part:
+ return _quick_verify(parts[1:], field.dataType)
+
+ return False
+
+ return _quick_verify(parts, schema)
+
+
+def verify_numeric_col_name(name: str, schema: StructType) -> bool:
+ parts = parse_attr_name(name)
+ if parts is None or len(parts) == 0:
+ return False
+
+ def _quick_verify(parts: List[str], dt: DataType) -> bool:
+ if len(parts) == 0 and isinstance(dt, NumericType):
+ return True
+
+ _schema: Optional[StructType] = None
+ if isinstance(dt, StructType):
+ _schema = dt
+ elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
+ _schema = dt.elementType
else:
return False
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index bbc089b00c13..ac868c34a913 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -126,6 +126,22 @@ class GroupTestsMixin:
with self.assertRaises(IndexError):
df.groupBy(10).agg(sf.sum("b"))
+ def test_numeric_agg_with_nest_type(self):
+ df = self.spark.createDataFrame(
+ [
+ Row(a="a", b=Row(c=1)),
+ Row(a="a", b=Row(c=2)),
+ Row(a="a", b=Row(c=3)),
+ Row(a="b", b=Row(c=4)),
+ Row(a="b", b=Row(c=5)),
+ ]
+ )
+
+ res = df.groupBy("a").max("b.c").sort("a").collect()
+ # [Row(a='a', max(b.c AS c)=3), Row(a='b', max(b.c AS c)=5)]
+
+ self.assertEqual([["a", 3], ["b", 5]], [list(r) for r in res])
+
@unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)  # type: ignore
def test_order_by_ordinal(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]