This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e7071c0237da [SPARK-50435][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_functions
e7071c0237da is described below
commit e7071c0237da75967b2f1e222d9f3b8293a82f86
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Dec 2 09:59:07 2024 +0800
[SPARK-50435][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_functions
### What changes were proposed in this pull request?
Use `assertDataFrameEqual` in pyspark.sql.tests.test_functions
### Why are the changes needed?
`assertDataFrameEqual` is explicitly built to handle DataFrame-specific
comparisons, including schema.
So we propose to replace `assertEqual` with `assertDataFrameEqual`.
Part of https://issues.apache.org/jira/browse/SPARK-50435.
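For illustration, a minimal self-contained sketch of the before/after comparison style (not part of this commit; the session setup and column name here are hypothetical):

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql import functions as F
from pyspark.testing.utils import assertDataFrameEqual

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("cat",)], ["c"])

# Before: materialize rows with .collect() and compare plain lists;
# the DataFrame schema is never checked.
# self.assertEqual(df.select(F.octet_length("c")).collect(), [Row(3)])

# After: pass the DataFrame directly. assertDataFrameEqual compares schemas
# when both sides are DataFrames, and row contents (unordered by default)
# when the expected side is given as a list of Rows.
assertDataFrameEqual(df.select(F.octet_length("c")), [Row(3)])
```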
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49011 from xinrong-meng/impr_test_functions.
Lead-authored-by: Xinrong Meng <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
---
python/pyspark/sql/tests/test_functions.py | 196 ++++++++++++++---------------
1 file changed, 92 insertions(+), 104 deletions(-)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index e192366676ad..4607d5d3411f 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -31,7 +31,7 @@ from pyspark.sql.avro.functions import from_avro, to_avro
from pyspark.sql.column import Column
from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull
from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
-from pyspark.testing.utils import have_numpy
+from pyspark.testing.utils import have_numpy, assertDataFrameEqual
class FunctionsTestsMixin:
@@ -344,29 +344,29 @@ class FunctionsTestsMixin:
[("https://spark.apache.org/path?query=1", "QUERY", "query")],
["url", "part", "key"],
)
- actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
- self.assertEqual(actual, [Row("1")])
+ actual = df.select(F.try_parse_url(df.url, df.part, df.key))
+ assertDataFrameEqual(actual, [Row("1")])
df = self.spark.createDataFrame(
[("inva lid://spark.apache.org/path?query=1", "QUERY", "query")],
["url", "part", "key"],
)
- actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
- self.assertEqual(actual, [Row(None)])
+ actual = df.select(F.try_parse_url(df.url, df.part, df.key))
+ assertDataFrameEqual(actual, [Row(None)])
def test_try_make_timestamp(self):
data = [(2024, 5, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second)
- ).collect()
- self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
+ )
+ assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
data = [(2024, 13, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp(df.year, df.month, df.day, df.hour, df.minute, df.second)
- ).collect()
- self.assertEqual(actual, [Row(None)])
+ )
+ assertDataFrameEqual(actual, [Row(None)])
def test_try_make_timestamp_ltz(self):
# use local timezone here to avoid flakiness
@@ -378,8 +378,8 @@ class FunctionsTestsMixin:
F.try_make_timestamp_ltz(
df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone
)
- ).collect()
- self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))])
+ )
+ assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 0))])
# use local timezone here to avoid flakiness
data = [(2024, 13, 22, 10, 30, 0, datetime.datetime.now().astimezone().tzinfo.__str__())]
@@ -390,23 +390,23 @@ class FunctionsTestsMixin:
F.try_make_timestamp_ltz(
df.year, df.month, df.day, df.hour, df.minute, df.second, df.timezone
)
- ).collect()
- self.assertEqual(actual, [Row(None)])
+ )
+ assertDataFrameEqual(actual, [Row(None)])
def test_try_make_timestamp_ntz(self):
data = [(2024, 5, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second)
- ).collect()
- self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
+ )
+ assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
data = [(2024, 13, 22, 10, 30, 0)]
df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", "minute", "second"])
actual = df.select(
F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, df.minute, df.second)
- ).collect()
- self.assertEqual(actual, [Row(None)])
+ )
+ assertDataFrameEqual(actual, [Row(None)])
def test_string_functions(self):
string_functions = [
@@ -448,51 +448,51 @@ class FunctionsTestsMixin:
)
for name in string_functions:
- self.assertEqual(
- df.select(getattr(F, name)("name")).first()[0],
- df.select(getattr(F, name)(F.col("name"))).first()[0],
+ assertDataFrameEqual(
+ df.select(getattr(F, name)("name")),
+ df.select(getattr(F, name)(F.col("name"))),
)
def test_collation(self):
df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
- actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct().collect()
- self.assertEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
+ actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct()
+ assertDataFrameEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
def test_try_make_interval(self):
df = self.spark.createDataFrame([(2147483647,)], ["num"])
- actual = df.select(F.isnull(F.try_make_interval("num"))).collect()
- self.assertEqual([Row(True)], actual)
+ actual = df.select(F.isnull(F.try_make_interval("num")))
+ assertDataFrameEqual([Row(True)], actual)
def test_octet_length_function(self):
# SPARK-36751: add octet length api for python
df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
- actual = df.select(F.octet_length("cat")).collect()
- self.assertEqual([Row(3), Row(4)], actual)
+ actual = df.select(F.octet_length("cat"))
+ assertDataFrameEqual([Row(3), Row(4)], actual)
def test_bit_length_function(self):
# SPARK-36751: add bit length api for python
df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
- actual = df.select(F.bit_length("cat")).collect()
- self.assertEqual([Row(24), Row(32)], actual)
+ actual = df.select(F.bit_length("cat"))
+ assertDataFrameEqual([Row(24), Row(32)], actual)
def test_array_contains_function(self):
df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ["data"])
- actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
- self.assertEqual([Row(b=True), Row(b=False)], actual)
+ actual = df.select(F.array_contains(df.data, "1").alias("b"))
+ assertDataFrameEqual([Row(b=True), Row(b=False)], actual)
def test_levenshtein_function(self):
df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
- actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")).collect()
- self.assertEqual([Row(b=3)], actual_without_threshold)
- actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")).collect()
- self.assertEqual([Row(b=-1)], actual_with_threshold)
+ actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b"))
+ assertDataFrameEqual([Row(b=3)], actual_without_threshold)
+ actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b"))
+ assertDataFrameEqual([Row(b=-1)], actual_with_threshold)
def test_between_function(self):
df = self.spark.createDataFrame(
[Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
)
- self.assertEqual(
- [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()
+ assertDataFrameEqual(
+ [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c))
)
def test_dayofweek(self):
@@ -608,7 +608,7 @@ class FunctionsTestsMixin:
F.last(df2.id, False).alias("c"),
F.last(df2.id, True).alias("d"),
)
- self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+ assertDataFrameEqual([Row(a=None, b=1, c=None, d=98)], df3)
def test_approxQuantile(self):
df = self.spark.createDataFrame([Row(a=i, b=i + 10) for i in range(10)])
@@ -666,20 +666,20 @@ class FunctionsTestsMixin:
df = self.spark.createDataFrame(
[("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"]
)
- self.assertEqual(
- df.select(df.name).orderBy(F.asc_nulls_first("name")).collect(),
+ assertDataFrameEqual(
+ df.select(df.name).orderBy(F.asc_nulls_first("name")),
[Row(name=None), Row(name="Alice"), Row(name="Tom")],
)
- self.assertEqual(
- df.select(df.name).orderBy(F.asc_nulls_last("name")).collect(),
+ assertDataFrameEqual(
+ df.select(df.name).orderBy(F.asc_nulls_last("name")),
[Row(name="Alice"), Row(name="Tom"), Row(name=None)],
)
- self.assertEqual(
- df.select(df.name).orderBy(F.desc_nulls_first("name")).collect(),
+ assertDataFrameEqual(
+ df.select(df.name).orderBy(F.desc_nulls_first("name")),
[Row(name=None), Row(name="Tom"), Row(name="Alice")],
)
- self.assertEqual(
- df.select(df.name).orderBy(F.desc_nulls_last("name")).collect(),
+ assertDataFrameEqual(
+ df.select(df.name).orderBy(F.desc_nulls_last("name")),
[Row(name="Tom"), Row(name="Alice"), Row(name=None)],
)
@@ -716,20 +716,16 @@ class FunctionsTestsMixin:
)
expected = [Row(sliced=[2, 3]), Row(sliced=[5])]
- self.assertEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")).collect(), expected)
- self.assertEqual(
- df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")).collect(), expected
- )
- self.assertEqual(
- df.select(F.slice("x", "index", "len").alias("sliced")).collect(),
expected
- )
+ assertDataFrameEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")), expected)
+ assertDataFrameEqual(df.select(F.slice(df.x, F.lit(2), F.lit(2)).alias("sliced")), expected)
+ assertDataFrameEqual(df.select(F.slice("x", "index",
"len").alias("sliced")), expected)
- self.assertEqual(
- df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")).collect(),
+ assertDataFrameEqual(
+ df.select(F.slice(df.x, F.size(df.x) - 1, F.lit(1)).alias("sliced")),
[Row(sliced=[2]), Row(sliced=[4])],
)
- self.assertEqual(
- df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")).collect(),
+ assertDataFrameEqual(
+ df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 1).alias("sliced")),
[Row(sliced=[1, 2]), Row(sliced=[4])],
)
@@ -738,11 +734,9 @@ class FunctionsTestsMixin:
df = df.withColumn("repeat_n", F.lit(3))
expected = [Row(val=[0, 0, 0])]
- self.assertEqual(df.select(F.array_repeat("id", 3).alias("val")).collect(), expected)
- self.assertEqual(df.select(F.array_repeat("id", F.lit(3)).alias("val")).collect(), expected)
- self.assertEqual(
- df.select(F.array_repeat("id", "repeat_n").alias("val")).collect(), expected
- )
+ assertDataFrameEqual(df.select(F.array_repeat("id", 3).alias("val")),
expected)
+ assertDataFrameEqual(df.select(F.array_repeat("id",
F.lit(3)).alias("val")), expected)
+ assertDataFrameEqual(df.select(F.array_repeat("id",
"repeat_n").alias("val")), expected)
def test_input_file_name_udf(self):
df = self.spark.read.text("python/test_support/hello/hello.txt")
@@ -754,11 +748,11 @@ class FunctionsTestsMixin:
df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])
expected = [Row(least=1)]
- self.assertEqual(df.select(F.least(df.a, df.b, df.c).alias("least")).collect(), expected)
- self.assertEqual(
- df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")).collect(), expected
+ assertDataFrameEqual(df.select(F.least(df.a, df.b, df.c).alias("least")), expected)
+ assertDataFrameEqual(
+ df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")), expected
)
- self.assertEqual(df.select(F.least("a", "b",
"c").alias("least")).collect(), expected)
+ assertDataFrameEqual(df.select(F.least("a", "b", "c").alias("least")),
expected)
with self.assertRaises(PySparkValueError) as pe:
df.select(F.least(df.a).alias("least")).collect()
@@ -800,11 +794,9 @@ class FunctionsTestsMixin:
df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x",
"y", "pos", "len"))
exp = [Row(ol="SPARK_CORESQL")]
- self.assertEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")).collect(), exp)
- self.assertEqual(
- df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")).collect(), exp
- )
- self.assertEqual(df.select(F.overlay("x", "y", "pos",
"len").alias("ol")).collect(), exp)
+ assertDataFrameEqual(df.select(F.overlay(df.x, df.y, 7, 0).alias("ol")), exp)
+ assertDataFrameEqual(df.select(F.overlay(df.x, df.y, F.lit(7), F.lit(0)).alias("ol")), exp)
+ assertDataFrameEqual(df.select(F.overlay("x", "y", "pos",
"len").alias("ol")), exp)
with self.assertRaises(PySparkTypeError) as pe:
df.select(F.overlay(df.x, df.y, 7.5, 0).alias("ol")).collect()
@@ -1164,8 +1156,8 @@ class FunctionsTestsMixin:
def check_assert_true(self, tpe):
df = self.spark.range(3)
- self.assertEqual(
- df.select(F.assert_true(df.id < 3)).toDF("val").collect(),
+ assertDataFrameEqual(
+ df.select(F.assert_true(df.id < 3)).toDF("val"),
[Row(val=None), Row(val=None), Row(val=None)],
)
@@ -1302,17 +1294,17 @@ class FunctionsTestsMixin:
df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
for dtype in [np.int8, np.int16, np.int32, np.int64]:
- res = df.select(F.array_contains(df.data, dtype(1)).alias("b")).collect()
- self.assertEqual([Row(b=True), Row(b=False)], res)
- res = df.select(F.array_position(df.data, dtype(1)).alias("c")).collect()
- self.assertEqual([Row(c=1), Row(c=0)], res)
+ res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
+ assertDataFrameEqual([Row(b=True), Row(b=False)], res)
+ res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
+ assertDataFrameEqual([Row(c=1), Row(c=0)], res)
df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
for dtype in [np.float32, np.float64]:
- res = df.select(F.array_contains(df.data, dtype(1)).alias("b")).collect()
- self.assertEqual([Row(b=True), Row(b=False)], res)
- res = df.select(F.array_position(df.data, dtype(1)).alias("c")).collect()
- self.assertEqual([Row(c=1), Row(c=0)], res)
+ res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
+ assertDataFrameEqual([Row(b=True), Row(b=False)], res)
+ res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
+ assertDataFrameEqual([Row(c=1), Row(c=0)], res)
@unittest.skipIf(not have_numpy, "NumPy not installed")
def test_ndarray_input(self):
@@ -1729,46 +1721,42 @@ class FunctionsTestsMixin:
def test_nullifzero_zeroifnull(self):
df = self.spark.createDataFrame([(0,), (1,)], ["a"])
- result = df.select(nullifzero(df.a).alias("r")).collect()
- self.assertEqual([Row(r=None), Row(r=1)], result)
+ result = df.select(nullifzero(df.a).alias("r"))
+ assertDataFrameEqual([Row(r=None), Row(r=1)], result)
df = self.spark.createDataFrame([(None,), (1,)], ["a"])
- result = df.select(zeroifnull(df.a).alias("r")).collect()
- self.assertEqual([Row(r=0), Row(r=1)], result)
+ result = df.select(zeroifnull(df.a).alias("r"))
+ assertDataFrameEqual([Row(r=0), Row(r=1)], result)
def test_randstr_uniform(self):
df = self.spark.createDataFrame([(0,)], ["a"])
- result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect()
- self.assertEqual([Row(5)], result)
+ result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)")
+ assertDataFrameEqual([Row(5)], result)
# The random seed is optional.
- result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
- self.assertEqual([Row(5)], result)
+ result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)")
+ assertDataFrameEqual([Row(5)], result)
df = self.spark.createDataFrame([(0,)], ["a"])
- result = (
- df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
- .selectExpr("x > 5")
- .collect()
- )
- self.assertEqual([Row(True)], result)
+ result = df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")).selectExpr("x > 5")
+ assertDataFrameEqual([Row(True)], result)
# The random seed is optional.
- result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect()
- self.assertEqual([Row(True)], result)
+ result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5")
+ assertDataFrameEqual([Row(True)], result)
def test_string_validation(self):
df = self.spark.createDataFrame([("abc",)], ["a"])
# test is_valid_utf8
- result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r")).collect()
- self.assertEqual([Row(r=True)], result_is_valid_utf8)
+ result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r"))
+ assertDataFrameEqual([Row(r=True)], result_is_valid_utf8)
# test make_valid_utf8
- result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r")).collect()
- self.assertEqual([Row(r="abc")], result_make_valid_utf8)
+ result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r"))
+ assertDataFrameEqual([Row(r="abc")], result_make_valid_utf8)
# test validate_utf8
- result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r")).collect()
- self.assertEqual([Row(r="abc")], result_validate_utf8)
+ result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r"))
+ assertDataFrameEqual([Row(r="abc")], result_validate_utf8)
# test try_validate_utf8
- result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r")).collect()
- self.assertEqual([Row(r="abc")], result_try_validate_utf8)
+ result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r"))
+ assertDataFrameEqual([Row(r="abc")], result_try_validate_utf8)
class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]