This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 183f9726938e [SPARK-53433][TESTS][FOLLOW-UP] Add a test for aggregation
183f9726938e is described below

commit 183f9726938ed5fe39373fd84f50451ff64fa2b9
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Sep 2 14:49:36 2025 +0900

    [SPARK-53433][TESTS][FOLLOW-UP] Add a test for aggregation
    
    ### What changes were proposed in this pull request?
    Add a test for aggregation
    
    ### Why are the changes needed?
    to improve test coverage
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52194 from zhengruifeng/agg_win_var.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py  | 47 +++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index 3545801c4b5a..f6c3112f94ca 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -20,7 +20,14 @@ import unittest
 from pyspark.sql.functions import arrow_udf, ArrowUDFType
 from pyspark.util import PythonEvalType
 from pyspark.sql import Row
-from pyspark.sql.types import ArrayType, YearMonthIntervalType
+from pyspark.sql.types import (
+    ArrayType,
+    YearMonthIntervalType,
+    StructType,
+    StructField,
+    VariantType,
+    VariantVal,
+)
 from pyspark.sql import functions as sf
 from pyspark.errors import AnalysisException, PythonException
 from pyspark.testing.sqlutils import (
@@ -807,6 +814,44 @@ class GroupedAggArrowUDFTestsMixin:
         result2 = df.groupby("i").agg(agg_min_time("t").alias("res")).sort("i")
         self.assertEqual(expected2.collect(), result2.collect())
 
+    def test_input_output_variant(self):
+        import pyarrow as pa
+
+        @arrow_udf("variant")
+        def first_variant(v: pa.Array) -> pa.Scalar:
+            assert isinstance(v, pa.Array)
+            assert isinstance(v, pa.StructArray)
+            assert isinstance(v.field("metadata"), pa.BinaryArray)
+            assert isinstance(v.field("value"), pa.BinaryArray)
+            return v[0]
+
+        @arrow_udf("variant")
+        def last_variant(v: pa.Array) -> pa.Scalar:
+            assert isinstance(v, pa.Array)
+            assert isinstance(v, pa.StructArray)
+            assert isinstance(v.field("metadata"), pa.BinaryArray)
+            assert isinstance(v.field("value"), pa.BinaryArray)
+            return v[-1]
+
+        df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
+        result = df.select(
+            first_variant("v").alias("first"),
+            last_variant("v").alias("last"),
+        )
+        self.assertEqual(
+            result.schema,
+            StructType(
+                [
+                    StructField("first", VariantType(), True),
+                    StructField("last", VariantType(), True),
+                ]
+            ),
+        )
+
+        row = result.first()
+        self.assertIsInstance(row.first, VariantVal)
+        self.assertIsInstance(row.last, VariantVal)
+
     def test_return_type_coercion(self):
         import pyarrow as pa
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to