This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 532de6476ff5 [SPARK-54790][PYTHON][TESTS] Add UDTF analyze coverage tests
532de6476ff5 is described below
commit 532de6476ff5d88f1184db67d112b1039b1185c2
Author: Tian Gao <[email protected]>
AuthorDate: Sun Dec 21 10:25:44 2025 +0900
[SPARK-54790][PYTHON][TESTS] Add UDTF analyze coverage tests
### What changes were proposed in this pull request?
Add coverage tests for the UDTF analyze worker.
### Why are the changes needed?
To increase test coverage of the UDTF analyze worker.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The new tests passed locally.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53554 from gaogaotiantian/udtf-tests.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_udtf.py | 123 +++++++++++++++++++++++++++++-----
1 file changed, 108 insertions(+), 15 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index eeb07400b060..de57c8d0cf38 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -48,6 +48,7 @@ from pyspark.sql.functions import (
AnalyzeResult,
OrderingColumn,
PartitioningColumn,
+ SelectedColumn,
SkipRestOfInputTableException,
)
from pyspark.sql.types import (
@@ -1847,6 +1848,25 @@ class BaseUDTFTestsMixin:
):
self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
+    def test_udtf_with_analyze_table_select(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(*args, **kwargs) -> AnalyzeResult:
+                return AnalyzeResult(
+                    StructType().add("id", IntegerType()), select=[SelectedColumn("id")]
+                )
+
+            def eval(self, row: Row):
+                assert "value" not in row
+                yield row["id"],
+
+        df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["id", "value"])
+        assertDataFrameEqual(
+            TestUDTF(df.asTable()).collect(),
+            [Row(id=1), Row(id=2), Row(id=3)],
+        )
+
def test_udtf_with_both_return_type_and_analyze(self):
class TestUDTF:
@staticmethod
@@ -1905,24 +1925,74 @@ class BaseUDTFTestsMixin:
messageParameters={"name": "TestUDTF"},
)
-    def test_udtf_with_analyze_returning_non_struct(self):
-        class TestUDTF:
-            @staticmethod
-            def analyze():
-                return StringType()
+    def test_udtf_with_scalar_analyze_returning_wrong_result(self):
+        # (wrong_type, error_message_regex)
+        invalid_results = [
+            (StringType(), r".*AnalyzeResult.*StringType.*"),
+            (AnalyzeResult(StringType()), r".*AnalyzeResult.*schema.*"),
+            (
+                AnalyzeResult(StructType().add("a", StringType()), withSinglePartition=True),
+                r".*withSinglePartition.*",
+            ),
+            (
+                AnalyzeResult(
+                    StructType().add("a", StringType()), partitionBy=[PartitioningColumn("a")]
+                ),
+                r".*partitionBy.*",
+            ),
+        ]
 
-            def eval(self):
-                yield "hello", "world"
+        for wrong_type, error_message_regex in invalid_results:
+            with self.subTest(wrong_type=wrong_type):
 
-        func = udtf(TestUDTF)
+                class TestUDTF:
+                    @staticmethod
+                    def analyze() -> AnalyzeResult:
+                        return wrong_type
 
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "'analyze' method expects a result of type pyspark.sql.udtf.AnalyzeResult, "
-            "but instead this method returned a value of type: "
-            "<class 'pyspark.sql.types.StringType'>",
-        ):
-            func().collect()
+                    def eval(self):
+                        yield "hello", "world"
+
+                func = udtf(TestUDTF)
+
+                with self.assertRaisesRegex(AnalysisException, error_message_regex):
+                    func().collect()
+
+    def test_udtf_with_table_analyze_returning_wrong_result(self):
+        invalid_results = [
+            (
+                AnalyzeResult(
+                    StructType().add("a", StringType()), partitionBy=[OrderingColumn("a")]
+                ),
+                r".*partitionBy.*",
+            ),
+            (
+                AnalyzeResult(
+                    StructType().add("a", StringType()), orderBy=[PartitioningColumn("a")]
+                ),
+                r".*orderBy.*",
+            ),
+            (
+                AnalyzeResult(StructType().add("a", StringType()), select=SelectedColumn("a")),
+                r".*select.*",
+            ),
+        ]
+
+        for wrong_type, error_message_regex in invalid_results:
+            with self.subTest(wrong_type=wrong_type):
+
+                class TestUDTF:
+                    @staticmethod
+                    def analyze(**kwargs) -> AnalyzeResult:
+                        return wrong_type
+
+                    def eval(self, **kwargs):
+                        yield tuple(value for _, value in sorted(kwargs.items()))
+
+                func = udtf(TestUDTF)
+
+                with self.assertRaisesRegex(AnalysisException, error_message_regex):
+                    func(a=self.spark.range(3).asTable(), b=lit("x")).collect()
def test_udtf_with_analyze_raising_an_exception(self):
class TestUDTF:
@@ -2829,6 +2899,29 @@ class BaseUDTFTestsMixin:
+ [Row(partition_col=42, count=3, total=3, last=None)],
)
+    def test_udtf_with_analyze_order_by_override_nulls_first(self):
+        for override_nulls_first in [True, False]:
+
+            @udtf
+            class TestUDTF:
+                @staticmethod
+                def analyze(*args, **kwargs) -> AnalyzeResult:
+                    return AnalyzeResult(
+                        StructType().add("id", IntegerType()),
+                        withSinglePartition=True,
+                        orderBy=[OrderingColumn("id", overrideNullsFirst=override_nulls_first)],
+                    )
+
+                def eval(self, row: Row):
+                    yield row["id"],
+
+            df = self.spark.createDataFrame([(1,), (None,)], ["id"])
+            assertDataFrameEqual(
+                TestUDTF(df.asTable()).collect(),
+                [Row(id=None), Row(id=1)] if override_nulls_first else [Row(id=1), Row(id=None)],
+                checkRowOrder=True,
+            )
+
def test_udtf_with_prepare_string_from_analyze(self):
@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]