This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 183f9726938e [SPARK-53433][TESTS][FOLLOW-UP] Add a test for aggregation
183f9726938e is described below
commit 183f9726938ed5fe39373fd84f50451ff64fa2b9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Sep 2 14:49:36 2025 +0900
[SPARK-53433][TESTS][FOLLOW-UP] Add a test for aggregation
### What changes were proposed in this pull request?
Add a test for aggregation
### Why are the changes needed?
to improve test coverage
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #52194 from zhengruifeng/agg_win_var.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/tests/arrow/test_arrow_udf_grouped_agg.py | 47 +++++++++++++++++++++-
1 file changed, 46 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 3545801c4b5a..f6c3112f94ca 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -20,7 +20,14 @@ import unittest
from pyspark.sql.functions import arrow_udf, ArrowUDFType
from pyspark.util import PythonEvalType
from pyspark.sql import Row
-from pyspark.sql.types import ArrayType, YearMonthIntervalType
+from pyspark.sql.types import (
+ ArrayType,
+ YearMonthIntervalType,
+ StructType,
+ StructField,
+ VariantType,
+ VariantVal,
+)
from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.sqlutils import (
@@ -807,6 +814,44 @@ class GroupedAggArrowUDFTestsMixin:
result2 = df.groupby("i").agg(agg_min_time("t").alias("res")).sort("i")
self.assertEqual(expected2.collect(), result2.collect())
+ def test_input_output_variant(self):
+ import pyarrow as pa
+
+ @arrow_udf("variant")
+ def first_variant(v: pa.Array) -> pa.Scalar:
+ assert isinstance(v, pa.Array)
+ assert isinstance(v, pa.StructArray)
+ assert isinstance(v.field("metadata"), pa.BinaryArray)
+ assert isinstance(v.field("value"), pa.BinaryArray)
+ return v[0]
+
+ @arrow_udf("variant")
+ def last_variant(v: pa.Array) -> pa.Scalar:
+ assert isinstance(v, pa.Array)
+ assert isinstance(v, pa.StructArray)
+ assert isinstance(v.field("metadata"), pa.BinaryArray)
+ assert isinstance(v.field("value"), pa.BinaryArray)
+ return v[-1]
+
+        df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
+ result = df.select(
+ first_variant("v").alias("first"),
+ last_variant("v").alias("last"),
+ )
+ self.assertEqual(
+ result.schema,
+ StructType(
+ [
+ StructField("first", VariantType(), True),
+ StructField("last", VariantType(), True),
+ ]
+ ),
+ )
+
+ row = result.first()
+ self.assertIsInstance(row.first, VariantVal)
+ self.assertIsInstance(row.last, VariantVal)
+
def test_return_type_coercion(self):
import pyarrow as pa
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]