This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 532de6476ff5 [SPARK-54790][PYTHON][TESTS] Add UDTF analyze coverage tests
532de6476ff5 is described below

commit 532de6476ff5d88f1184db67d112b1039b1185c2
Author: Tian Gao <[email protected]>
AuthorDate: Sun Dec 21 10:25:44 2025 +0900

    [SPARK-54790][PYTHON][TESTS] Add UDTF analyze coverage tests
    
    ### What changes were proposed in this pull request?
    
    Add coverage tests for UDTF analyze worker.
    
    ### Why are the changes needed?
    
    Increase test coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Passed locally.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53554 from gaogaotiantian/udtf-tests.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/tests/test_udtf.py | 123 +++++++++++++++++++++++++++++-----
 1 file changed, 108 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index eeb07400b060..de57c8d0cf38 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -48,6 +48,7 @@ from pyspark.sql.functions import (
     AnalyzeResult,
     OrderingColumn,
     PartitioningColumn,
+    SelectedColumn,
     SkipRestOfInputTableException,
 )
 from pyspark.sql.types import (
@@ -1847,6 +1848,25 @@ class BaseUDTFTestsMixin:
         ):
             self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
 
+    def test_udtf_with_analyze_table_select(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(*args, **kwargs) -> AnalyzeResult:
+                return AnalyzeResult(
+                    StructType().add("id", IntegerType()), 
select=[SelectedColumn("id")]
+                )
+
+            def eval(self, row: Row):
+                assert "value" not in row
+                yield row["id"],
+
+        df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", 
"value"])
+        assertDataFrameEqual(
+            TestUDTF(df.asTable()).collect(),
+            [Row(id=1), Row(id=2), Row(id=3)],
+        )
+
     def test_udtf_with_both_return_type_and_analyze(self):
         class TestUDTF:
             @staticmethod
@@ -1905,24 +1925,74 @@ class BaseUDTFTestsMixin:
             messageParameters={"name": "TestUDTF"},
         )
 
-    def test_udtf_with_analyze_returning_non_struct(self):
-        class TestUDTF:
-            @staticmethod
-            def analyze():
-                return StringType()
+    def test_udtf_with_scalar_analyze_returning_wrong_result(self):
+        # (wrong_type, error_message_regex)
+        invalid_results = [
+            (StringType(), r".*AnalyzeResult.*StringType.*"),
+            (AnalyzeResult(StringType()), r".*AnalyzeResult.*schema.*"),
+            (
+                AnalyzeResult(StructType().add("a", StringType()), 
withSinglePartition=True),
+                r".*withSinglePartition.*",
+            ),
+            (
+                AnalyzeResult(
+                    StructType().add("a", StringType()), 
partitionBy=[PartitioningColumn("a")]
+                ),
+                r".*partitionBy.*",
+            ),
+        ]
 
-            def eval(self):
-                yield "hello", "world"
+        for wrong_type, error_message_regex in invalid_results:
+            with self.subTest(wrong_type=wrong_type):
 
-        func = udtf(TestUDTF)
+                class TestUDTF:
+                    @staticmethod
+                    def analyze() -> AnalyzeResult:
+                        return wrong_type
 
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "'analyze' method expects a result of type 
pyspark.sql.udtf.AnalyzeResult, "
-            "but instead this method returned a value of type: "
-            "<class 'pyspark.sql.types.StringType'>",
-        ):
-            func().collect()
+                    def eval(self):
+                        yield "hello", "world"
+
+                func = udtf(TestUDTF)
+
+                with self.assertRaisesRegex(AnalysisException, 
error_message_regex):
+                    func().collect()
+
+    def test_udtf_with_table_analyze_returning_wrong_result(self):
+        invalid_results = [
+            (
+                AnalyzeResult(
+                    StructType().add("a", StringType()), 
partitionBy=[OrderingColumn("a")]
+                ),
+                r".*partitionBy.*",
+            ),
+            (
+                AnalyzeResult(
+                    StructType().add("a", StringType()), 
orderBy=[PartitioningColumn("a")]
+                ),
+                r".*orderBy.*",
+            ),
+            (
+                AnalyzeResult(StructType().add("a", StringType()), 
select=SelectedColumn("a")),
+                r".*select.*",
+            ),
+        ]
+
+        for wrong_type, error_message_regex in invalid_results:
+            with self.subTest(wrong_type=wrong_type):
+
+                class TestUDTF:
+                    @staticmethod
+                    def analyze(**kwargs) -> AnalyzeResult:
+                        return wrong_type
+
+                    def eval(self, **kwargs):
+                        yield tuple(value for _, value in 
sorted(kwargs.items()))
+
+                func = udtf(TestUDTF)
+
+                with self.assertRaisesRegex(AnalysisException, 
error_message_regex):
+                    func(a=self.spark.range(3).asTable(), b=lit("x")).collect()
 
     def test_udtf_with_analyze_raising_an_exception(self):
         class TestUDTF:
@@ -2829,6 +2899,29 @@ class BaseUDTFTestsMixin:
                     + [Row(partition_col=42, count=3, total=3, last=None)],
                 )
 
+    def test_udtf_with_analyze_order_by_override_nulls_first(self):
+        for override_nulls_first in [True, False]:
+
+            @udtf
+            class TestUDTF:
+                @staticmethod
+                def analyze(*args, **kwargs) -> AnalyzeResult:
+                    return AnalyzeResult(
+                        StructType().add("id", IntegerType()),
+                        withSinglePartition=True,
+                        orderBy=[OrderingColumn("id", 
overrideNullsFirst=override_nulls_first)],
+                    )
+
+                def eval(self, row: Row):
+                    yield row["id"],
+
+            df = self.spark.createDataFrame([(1,), (None,)], ["id"])
+            assertDataFrameEqual(
+                TestUDTF(df.asTable()).collect(),
+                [Row(id=None), Row(id=1)] if override_nulls_first else 
[Row(id=1), Row(id=None)],
+                checkRowOrder=True,
+            )
+
     def test_udtf_with_prepare_string_from_analyze(self):
         @dataclass
         class AnalyzeResultWithBuffer(AnalyzeResult):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to