This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 31d996d4f3d8 [SPARK-54453][PYTHON][TEST] Add more coverage tests to conversion
31d996d4f3d8 is described below
commit 31d996d4f3d8acd698a791745c01b6dcdccf925d
Author: Tian Gao <[email protected]>
AuthorDate: Tue Dec 16 14:10:08 2025 +0800
[SPARK-54453][PYTHON][TEST] Add more coverage tests to conversion
### What changes were proposed in this pull request?
Add more coverage tests for `conversion.py`.
The test framework is also restructured in the process: it was hard to add
extra test cases to the previous framework because the schema, the input data,
and the expected output lived in separate places and had to be kept in sync by
counting field positions.
Now the schema, input, and output are grouped together, so anyone can add new
test cases in a single place, as sketched below.
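For illustration, a minimal sketch of the new layout (names mirror the test in
the diff below; the `ArrayType` entry is a hypothetical extra case): each entry
pairs a schema with one or more cases, where `(value,)` means the value should
round-trip unchanged and `(before, after)` means a different value is expected
back:

```python
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType

data = [
    (IntegerType(), (1,), (None,)),             # two cases sharing one schema
    (ArrayType(StringType()), (["a", None],)),  # a new test is just one new entry
]

# Build one wide schema plus matching input/expected lists from the table.
schema = StructType()
input_row, expected = [], []
index = 0
for row_schema, *tests in data:
    for test in tests:
        before, after = test[0], test[-1]  # one-element case: before == after
        schema.add(f"{row_schema.simpleString()}_{index}", row_schema)
        input_row.append(before)
        expected.append(after)
        index += 1
```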
### Why are the changes needed?
To improve coverage and make it easier to add test cases in the future.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
After the framework change, confirmed that the coverage is exactly the same as
before. Then confirmed that the coverage rate improves with the new test cases.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53479 from gaogaotiantian/conversion-test.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/tests/test_conversion.py | 187 ++++++++++++++++++----------
1 file changed, 121 insertions(+), 66 deletions(-)
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index ca3b6f6671aa..f62244e1279f 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -16,99 +16,142 @@
#
import unittest
-from pyspark.sql.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
+from pyspark.errors import PySparkValueError
+from pyspark.sql.conversion import (
+ ArrowTableToRowsConversion,
+ LocalDataToArrowConversion,
+)
from pyspark.sql.types import (
ArrayType,
BinaryType,
+ GeographyType,
+ GeometryType,
IntegerType,
MapType,
+ NullType,
Row,
StringType,
+ StructField,
StructType,
+ UserDefinedType,
)
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.utils import have_pyarrow, pyarrow_requirement_message
+class ScoreUDT(UserDefinedType):
+ @classmethod
+ def sqlType(cls):
+ return IntegerType()
+
+ def serialize(self, obj):
+ return obj.score
+
+ def deserialize(self, datum):
+ return Score(datum)
+
+
+class Score:
+ __UDT__ = ScoreUDT()
+
+ def __init__(self, score):
+ self.score = score
+
+ def __eq__(self, other):
+ return self.score == other.score
+
+
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ConversionTests(unittest.TestCase):
def test_conversion(self):
data = [
+ # Schema, Test cases (Before, After_If_Different)
+ (NullType(), (None,)),
+ (IntegerType(), (1,), (None,)),
+ ((IntegerType(), {"nullable": False}), (1,)),
+ (StringType(), ("a",)),
+ (BinaryType(), (b"a",)),
+ (GeographyType("ANY"), (None,)),
+ (GeometryType("ANY"), (None,)),
+ (ArrayType(IntegerType()), ([1, None],)),
+ (ArrayType(IntegerType(), containsNull=False), ([1, 2],)),
+ (ArrayType(BinaryType()), ([b"a", b"b"],)),
+ (MapType(StringType(), IntegerType()), ({"a": 1, "b": None},)),
(
- i if i % 2 == 0 else None,
- str(i),
- i,
- str(i).encode(),
- [j if j % 2 == 0 else None for j in range(i)],
- list(range(i)),
- [str(j).encode() for j in range(i)],
- {str(j): j if j % 2 == 0 else None for j in range(i)},
- {str(j): j for j in range(i)},
- {str(j): str(j).encode() for j in range(i)},
- (i if i % 2 == 0 else None, str(i), i, str(i).encode()),
- {"i": i if i % 2 == 0 else None, "s": str(i), "ii": i, "b":
str(i).encode()},
- ExamplePoint(float(i), float(i)),
- )
- for i in range(5)
+ MapType(StringType(), IntegerType(), valueContainsNull=False),
+ ({"a": 1},),
+ ),
+ (MapType(StringType(), BinaryType()), ({"a": b"a"},)),
+ (
+ StructType(
+ [
+ StructField("i", IntegerType()),
+ StructField("i_n", IntegerType()),
+ StructField("ii", IntegerType(), nullable=False),
+ StructField("s", StringType()),
+ StructField("b", BinaryType()),
+ ]
+ ),
+ ((1, None, 1, "a", b"a"), Row(i=1, i_n=None, ii=1, s="a", b=b"a")),
+ (
+ {"b": b"a", "s": "a", "ii": 1, "in": None, "i": 1},
+ Row(i=1, i_n=None, ii=1, s="a", b=b"a"),
+ ),
+ ),
+ (ExamplePointUDT(), (ExamplePoint(1.0, 1.0),)),
+ (ScoreUDT(), (Score(1),)),
]
- schema = (
- StructType()
- .add("i", IntegerType())
- .add("s", StringType())
- .add("ii", IntegerType(), nullable=False)
- .add("b", BinaryType())
- .add("arr_i", ArrayType(IntegerType()))
- .add("arr_ii", ArrayType(IntegerType(), containsNull=False))
- .add("arr_b", ArrayType(BinaryType()))
- .add("map_i", MapType(StringType(), IntegerType()))
- .add("map_ii", MapType(StringType(), IntegerType(),
valueContainsNull=False))
- .add("map_b", MapType(StringType(), BinaryType()))
- .add(
- "struct_t",
- StructType()
- .add("i", IntegerType())
- .add("s", StringType())
- .add("ii", IntegerType(), nullable=False)
- .add("b", BinaryType()),
- )
- .add(
- "struct_d",
- StructType()
- .add("i", IntegerType())
- .add("s", StringType())
- .add("ii", IntegerType(), nullable=False)
- .add("b", BinaryType()),
- )
- .add("udt", ExamplePointUDT())
- )
- tbl = LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False)
+ schema = StructType()
+
+ input_row = []
+ expected = []
+
+ index = 0
+ for row_schema, *tests in data:
+ if isinstance(row_schema, tuple):
+ row_schema, kwargs = row_schema
+ else:
+ kwargs = {}
+ for test in tests:
+ if len(test) == 1:
+ before, after = test[0], test[0]
+ else:
+ before, after = test
+ schema.add(f"{row_schema.simpleString()}_{index}", row_schema,
**kwargs)
+ input_row.append(before)
+ expected.append(after)
+ index += 1
+
+ tbl = LocalDataToArrowConversion.convert(
+ [tuple(input_row)], schema, use_large_var_types=False
+ )
actual = ArrowTableToRowsConversion.convert(tbl, schema)
for a, e in zip(
- actual,
- [
- Row(
- i=i if i % 2 == 0 else None,
- s=str(i),
- ii=i,
- b=str(i).encode(),
- arr_i=[j if j % 2 == 0 else None for j in range(i)],
- arr_ii=list(range(i)),
- arr_b=[str(j).encode() for j in range(i)],
- map_i={str(j): j if j % 2 == 0 else None for j in range(i)},
- map_ii={str(j): j for j in range(i)},
- map_b={str(j): str(j).encode() for j in range(i)},
- struct_t=Row(i=i if i % 2 == 0 else None, s=str(i), ii=i, b=str(i).encode()),
- struct_d=Row(i=i if i % 2 == 0 else None, s=str(i), ii=i, b=str(i).encode()),
- udt=ExamplePoint(float(i), float(i)),
- )
- for i in range(5)
- ],
+ actual[0],
+ expected,
):
with self.subTest(expected=e):
self.assertEqual(a, e)
+ def test_none_as_row(self):
+ schema = StructType([StructField("x", IntegerType())])
+ tbl = LocalDataToArrowConversion.convert([None], schema, use_large_var_types=False)
+ actual = ArrowTableToRowsConversion.convert(tbl, schema)
+ self.assertEqual(actual[0], Row(x=None))
+
+ def test_return_as_tuples(self):
+ schema = StructType([StructField("x", IntegerType())])
+ tbl = LocalDataToArrowConversion.convert([(1,)], schema, use_large_var_types=False)
+ actual = ArrowTableToRowsConversion.convert(tbl, schema, return_as_tuples=True)
+ self.assertEqual(actual[0], (1,))
+
+ schema = StructType()
+ tbl = LocalDataToArrowConversion.convert([tuple()], schema, use_large_var_types=False)
+ actual = ArrowTableToRowsConversion.convert(tbl, schema, return_as_tuples=True)
+ self.assertEqual(actual[0], tuple())
+
def test_binary_as_bytes_conversion(self):
data = [
(
@@ -146,6 +189,18 @@ class ConversionTests(unittest.TestCase):
# Struct field
self.assertIsInstance(row.struct_b.b, expected_type)
+ def test_invalid_conversion(self):
+ data = [
+ (NullType(), 1),
+ (ArrayType(IntegerType(), containsNull=False), [1, None]),
+ (ArrayType(ScoreUDT(), containsNull=False), [None]),
+ ]
+
+ for row_schema, value in data:
+ schema = StructType([StructField("x", row_schema)])
+ with self.assertRaises(PySparkValueError):
+ LocalDataToArrowConversion.convert([(value,)], schema, use_large_var_types=False)
+
if __name__ == "__main__":
from pyspark.sql.tests.test_conversion import * # noqa: F401
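For context, a minimal sketch of the round-trip these tests exercise (assuming
a PySpark checkout with pyarrow installed; the calls and keyword arguments
mirror the test code above):

```python
from pyspark.sql.conversion import (
    ArrowTableToRowsConversion,
    LocalDataToArrowConversion,
)
from pyspark.sql.types import IntegerType, Row, StringType, StructType

# Convert local tuples to a pyarrow Table, then back to Rows.
schema = StructType().add("i", IntegerType()).add("s", StringType())
tbl = LocalDataToArrowConversion.convert([(1, "a")], schema, use_large_var_types=False)
rows = ArrowTableToRowsConversion.convert(tbl, schema)
assert rows[0] == Row(i=1, s="a")
```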