This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e7071c0237da [SPARK-50435][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_functions
e7071c0237da is described below

commit e7071c0237da75967b2f1e222d9f3b8293a82f86
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Dec 2 09:59:07 2024 +0800

    [SPARK-50435][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_functions
    
    ### What changes were proposed in this pull request?
    Use `assertDataFrameEqual` in pyspark.sql.tests.test_functions
    
    ### Why are the changes needed?
    `assertDataFrameEqual` is explicitly built to handle DataFrame-specific comparisons, including schema.
    
    So we propose to replace `assertEqual` with `assertDataFrameEqual`.
    
    Part of https://issues.apache.org/jira/browse/SPARK-50435.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49011 from xinrong-meng/impr_test_functions.
    
    Lead-authored-by: Xinrong Meng <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Xinrong Meng <[email protected]>
---
 python/pyspark/sql/tests/test_functions.py | 196 ++++++++++++++---------------
 1 file changed, 92 insertions(+), 104 deletions(-)

diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index e192366676ad..4607d5d3411f 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -31,7 +31,7 @@ from pyspark.sql.avro.functions import from_avro, to_avro
 from pyspark.sql.column import Column
 from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, 
zeroifnull
 from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils
-from pyspark.testing.utils import have_numpy
+from pyspark.testing.utils import have_numpy, assertDataFrameEqual
 
 
 class FunctionsTestsMixin:
@@ -344,29 +344,29 @@ class FunctionsTestsMixin:
             [("https://spark.apache.org/path?query=1", "QUERY", "query")],
             ["url", "part", "key"],
         )
-        actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
-        self.assertEqual(actual, [Row("1")])
+        actual = df.select(F.try_parse_url(df.url, df.part, df.key))
+        assertDataFrameEqual(actual, [Row("1")])
         df = self.spark.createDataFrame(
             [("inva lid://spark.apache.org/path?query=1", "QUERY", "query")],
             ["url", "part", "key"],
         )
-        actual = df.select(F.try_parse_url(df.url, df.part, df.key)).collect()
-        self.assertEqual(actual, [Row(None)])
+        actual = df.select(F.try_parse_url(df.url, df.part, df.key))
+        assertDataFrameEqual(actual, [Row(None)])
 
     def test_try_make_timestamp(self):
         data = [(2024, 5, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", 
"minute", "second"])
         actual = df.select(
             F.try_make_timestamp(df.year, df.month, df.day, df.hour, 
df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
+        )
+        assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 
30))])
 
         data = [(2024, 13, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", 
"minute", "second"])
         actual = df.select(
             F.try_make_timestamp(df.year, df.month, df.day, df.hour, 
df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(None)])
+        )
+        assertDataFrameEqual(actual, [Row(None)])
 
     def test_try_make_timestamp_ltz(self):
         # use local timezone here to avoid flakiness
@@ -378,8 +378,8 @@ class FunctionsTestsMixin:
             F.try_make_timestamp_ltz(
                 df.year, df.month, df.day, df.hour, df.minute, df.second, 
df.timezone
             )
-        ).collect()
-        self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30, 
0))])
+        )
+        assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 
30, 0))])
 
         # use local timezone here to avoid flakiness
         data = [(2024, 13, 22, 10, 30, 0, 
datetime.datetime.now().astimezone().tzinfo.__str__())]
@@ -390,23 +390,23 @@ class FunctionsTestsMixin:
             F.try_make_timestamp_ltz(
                 df.year, df.month, df.day, df.hour, df.minute, df.second, 
df.timezone
             )
-        ).collect()
-        self.assertEqual(actual, [Row(None)])
+        )
+        assertDataFrameEqual(actual, [Row(None)])
 
     def test_try_make_timestamp_ntz(self):
         data = [(2024, 5, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", 
"minute", "second"])
         actual = df.select(
             F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, 
df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 30))])
+        )
+        assertDataFrameEqual(actual, [Row(datetime.datetime(2024, 5, 22, 10, 
30))])
 
         data = [(2024, 13, 22, 10, 30, 0)]
         df = self.spark.createDataFrame(data, ["year", "month", "day", "hour", 
"minute", "second"])
         actual = df.select(
             F.try_make_timestamp_ntz(df.year, df.month, df.day, df.hour, 
df.minute, df.second)
-        ).collect()
-        self.assertEqual(actual, [Row(None)])
+        )
+        assertDataFrameEqual(actual, [Row(None)])
 
     def test_string_functions(self):
         string_functions = [
@@ -448,51 +448,51 @@ class FunctionsTestsMixin:
         )
 
         for name in string_functions:
-            self.assertEqual(
-                df.select(getattr(F, name)("name")).first()[0],
-                df.select(getattr(F, name)(F.col("name"))).first()[0],
+            assertDataFrameEqual(
+                df.select(getattr(F, name)("name")),
+                df.select(getattr(F, name)(F.col("name"))),
             )
 
     def test_collation(self):
         df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
-        actual = df.select(F.collation(F.collate("name", 
"UNICODE"))).distinct().collect()
-        self.assertEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
+        actual = df.select(F.collation(F.collate("name", 
"UNICODE"))).distinct()
+        assertDataFrameEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
 
     def test_try_make_interval(self):
         df = self.spark.createDataFrame([(2147483647,)], ["num"])
-        actual = df.select(F.isnull(F.try_make_interval("num"))).collect()
-        self.assertEqual([Row(True)], actual)
+        actual = df.select(F.isnull(F.try_make_interval("num")))
+        assertDataFrameEqual([Row(True)], actual)
 
     def test_octet_length_function(self):
         # SPARK-36751: add octet length api for python
         df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
-        actual = df.select(F.octet_length("cat")).collect()
-        self.assertEqual([Row(3), Row(4)], actual)
+        actual = df.select(F.octet_length("cat"))
+        assertDataFrameEqual([Row(3), Row(4)], actual)
 
     def test_bit_length_function(self):
         # SPARK-36751: add bit length api for python
         df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
-        actual = df.select(F.bit_length("cat")).collect()
-        self.assertEqual([Row(24), Row(32)], actual)
+        actual = df.select(F.bit_length("cat"))
+        assertDataFrameEqual([Row(24), Row(32)], actual)
 
     def test_array_contains_function(self):
         df = self.spark.createDataFrame([(["1", "2", "3"],), ([],)], ["data"])
-        actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
-        self.assertEqual([Row(b=True), Row(b=False)], actual)
+        actual = df.select(F.array_contains(df.data, "1").alias("b"))
+        assertDataFrameEqual([Row(b=True), Row(b=False)], actual)
 
     def test_levenshtein_function(self):
         df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
-        actual_without_threshold = df.select(F.levenshtein(df.l, 
df.r).alias("b")).collect()
-        self.assertEqual([Row(b=3)], actual_without_threshold)
-        actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 
2).alias("b")).collect()
-        self.assertEqual([Row(b=-1)], actual_with_threshold)
+        actual_without_threshold = df.select(F.levenshtein(df.l, 
df.r).alias("b"))
+        assertDataFrameEqual([Row(b=3)], actual_without_threshold)
+        actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 
2).alias("b"))
+        assertDataFrameEqual([Row(b=-1)], actual_with_threshold)
 
     def test_between_function(self):
         df = self.spark.createDataFrame(
             [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
         )
-        self.assertEqual(
-            [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], 
df.filter(df.a.between(df.b, df.c)).collect()
+        assertDataFrameEqual(
+            [Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], 
df.filter(df.a.between(df.b, df.c))
         )
 
     def test_dayofweek(self):
@@ -608,7 +608,7 @@ class FunctionsTestsMixin:
             F.last(df2.id, False).alias("c"),
             F.last(df2.id, True).alias("d"),
         )
-        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+        assertDataFrameEqual([Row(a=None, b=1, c=None, d=98)], df3)
 
     def test_approxQuantile(self):
         df = self.spark.createDataFrame([Row(a=i, b=i + 10) for i in 
range(10)])
@@ -666,20 +666,20 @@ class FunctionsTestsMixin:
         df = self.spark.createDataFrame(
             [("Tom", 80), (None, 60), ("Alice", 50)], ["name", "height"]
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.asc_nulls_first("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.asc_nulls_first("name")),
             [Row(name=None), Row(name="Alice"), Row(name="Tom")],
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.asc_nulls_last("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.asc_nulls_last("name")),
             [Row(name="Alice"), Row(name="Tom"), Row(name=None)],
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.desc_nulls_first("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.desc_nulls_first("name")),
             [Row(name=None), Row(name="Tom"), Row(name="Alice")],
         )
-        self.assertEqual(
-            df.select(df.name).orderBy(F.desc_nulls_last("name")).collect(),
+        assertDataFrameEqual(
+            df.select(df.name).orderBy(F.desc_nulls_last("name")),
             [Row(name="Tom"), Row(name="Alice"), Row(name=None)],
         )
 
@@ -716,20 +716,16 @@ class FunctionsTestsMixin:
         )
 
         expected = [Row(sliced=[2, 3]), Row(sliced=[5])]
-        self.assertEqual(df.select(F.slice(df.x, 2, 
2).alias("sliced")).collect(), expected)
-        self.assertEqual(
-            df.select(F.slice(df.x, F.lit(2), 
F.lit(2)).alias("sliced")).collect(), expected
-        )
-        self.assertEqual(
-            df.select(F.slice("x", "index", "len").alias("sliced")).collect(), 
expected
-        )
+        assertDataFrameEqual(df.select(F.slice(df.x, 2, 2).alias("sliced")), 
expected)
+        assertDataFrameEqual(df.select(F.slice(df.x, F.lit(2), 
F.lit(2)).alias("sliced")), expected)
+        assertDataFrameEqual(df.select(F.slice("x", "index", 
"len").alias("sliced")), expected)
 
-        self.assertEqual(
-            df.select(F.slice(df.x, F.size(df.x) - 1, 
F.lit(1)).alias("sliced")).collect(),
+        assertDataFrameEqual(
+            df.select(F.slice(df.x, F.size(df.x) - 1, 
F.lit(1)).alias("sliced")),
             [Row(sliced=[2]), Row(sliced=[4])],
         )
-        self.assertEqual(
-            df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 
1).alias("sliced")).collect(),
+        assertDataFrameEqual(
+            df.select(F.slice(df.x, F.lit(1), F.size(df.x) - 
1).alias("sliced")),
             [Row(sliced=[1, 2]), Row(sliced=[4])],
         )
 
@@ -738,11 +734,9 @@ class FunctionsTestsMixin:
         df = df.withColumn("repeat_n", F.lit(3))
 
         expected = [Row(val=[0, 0, 0])]
-        self.assertEqual(df.select(F.array_repeat("id", 
3).alias("val")).collect(), expected)
-        self.assertEqual(df.select(F.array_repeat("id", 
F.lit(3)).alias("val")).collect(), expected)
-        self.assertEqual(
-            df.select(F.array_repeat("id", 
"repeat_n").alias("val")).collect(), expected
-        )
+        assertDataFrameEqual(df.select(F.array_repeat("id", 3).alias("val")), 
expected)
+        assertDataFrameEqual(df.select(F.array_repeat("id", 
F.lit(3)).alias("val")), expected)
+        assertDataFrameEqual(df.select(F.array_repeat("id", 
"repeat_n").alias("val")), expected)
 
     def test_input_file_name_udf(self):
         df = self.spark.read.text("python/test_support/hello/hello.txt")
@@ -754,11 +748,11 @@ class FunctionsTestsMixin:
         df = self.spark.createDataFrame([(1, 4, 3)], ["a", "b", "c"])
 
         expected = [Row(least=1)]
-        self.assertEqual(df.select(F.least(df.a, df.b, 
df.c).alias("least")).collect(), expected)
-        self.assertEqual(
-            df.select(F.least(F.lit(3), F.lit(5), 
F.lit(1)).alias("least")).collect(), expected
+        assertDataFrameEqual(df.select(F.least(df.a, df.b, 
df.c).alias("least")), expected)
+        assertDataFrameEqual(
+            df.select(F.least(F.lit(3), F.lit(5), F.lit(1)).alias("least")), 
expected
         )
-        self.assertEqual(df.select(F.least("a", "b", 
"c").alias("least")).collect(), expected)
+        assertDataFrameEqual(df.select(F.least("a", "b", "c").alias("least")), 
expected)
 
         with self.assertRaises(PySparkValueError) as pe:
             df.select(F.least(df.a).alias("least")).collect()
@@ -800,11 +794,9 @@ class FunctionsTestsMixin:
         df = self.spark.createDataFrame([("SPARK_SQL", "CORE", 7, 0)], ("x", 
"y", "pos", "len"))
 
         exp = [Row(ol="SPARK_CORESQL")]
-        self.assertEqual(df.select(F.overlay(df.x, df.y, 7, 
0).alias("ol")).collect(), exp)
-        self.assertEqual(
-            df.select(F.overlay(df.x, df.y, F.lit(7), 
F.lit(0)).alias("ol")).collect(), exp
-        )
-        self.assertEqual(df.select(F.overlay("x", "y", "pos", 
"len").alias("ol")).collect(), exp)
+        assertDataFrameEqual(df.select(F.overlay(df.x, df.y, 7, 
0).alias("ol")), exp)
+        assertDataFrameEqual(df.select(F.overlay(df.x, df.y, F.lit(7), 
F.lit(0)).alias("ol")), exp)
+        assertDataFrameEqual(df.select(F.overlay("x", "y", "pos", 
"len").alias("ol")), exp)
 
         with self.assertRaises(PySparkTypeError) as pe:
             df.select(F.overlay(df.x, df.y, 7.5, 0).alias("ol")).collect()
@@ -1164,8 +1156,8 @@ class FunctionsTestsMixin:
     def check_assert_true(self, tpe):
         df = self.spark.range(3)
 
-        self.assertEqual(
-            df.select(F.assert_true(df.id < 3)).toDF("val").collect(),
+        assertDataFrameEqual(
+            df.select(F.assert_true(df.id < 3)).toDF("val"),
             [Row(val=None), Row(val=None), Row(val=None)],
         )
 
@@ -1302,17 +1294,17 @@ class FunctionsTestsMixin:
 
         df = self.spark.createDataFrame([([1, 2, 3],), ([],)], ["data"])
         for dtype in [np.int8, np.int16, np.int32, np.int64]:
-            res = df.select(F.array_contains(df.data, 
dtype(1)).alias("b")).collect()
-            self.assertEqual([Row(b=True), Row(b=False)], res)
-            res = df.select(F.array_position(df.data, 
dtype(1)).alias("c")).collect()
-            self.assertEqual([Row(c=1), Row(c=0)], res)
+            res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
+            assertDataFrameEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
+            assertDataFrameEqual([Row(c=1), Row(c=0)], res)
 
         df = self.spark.createDataFrame([([1.0, 2.0, 3.0],), ([],)], ["data"])
         for dtype in [np.float32, np.float64]:
-            res = df.select(F.array_contains(df.data, 
dtype(1)).alias("b")).collect()
-            self.assertEqual([Row(b=True), Row(b=False)], res)
-            res = df.select(F.array_position(df.data, 
dtype(1)).alias("c")).collect()
-            self.assertEqual([Row(c=1), Row(c=0)], res)
+            res = df.select(F.array_contains(df.data, dtype(1)).alias("b"))
+            assertDataFrameEqual([Row(b=True), Row(b=False)], res)
+            res = df.select(F.array_position(df.data, dtype(1)).alias("c"))
+            assertDataFrameEqual([Row(c=1), Row(c=0)], res)
 
     @unittest.skipIf(not have_numpy, "NumPy not installed")
     def test_ndarray_input(self):
@@ -1729,46 +1721,42 @@ class FunctionsTestsMixin:
 
     def test_nullifzero_zeroifnull(self):
         df = self.spark.createDataFrame([(0,), (1,)], ["a"])
-        result = df.select(nullifzero(df.a).alias("r")).collect()
-        self.assertEqual([Row(r=None), Row(r=1)], result)
+        result = df.select(nullifzero(df.a).alias("r"))
+        assertDataFrameEqual([Row(r=None), Row(r=1)], result)
 
         df = self.spark.createDataFrame([(None,), (1,)], ["a"])
-        result = df.select(zeroifnull(df.a).alias("r")).collect()
-        self.assertEqual([Row(r=0), Row(r=1)], result)
+        result = df.select(zeroifnull(df.a).alias("r"))
+        assertDataFrameEqual([Row(r=0), Row(r=1)], result)
 
     def test_randstr_uniform(self):
         df = self.spark.createDataFrame([(0,)], ["a"])
-        result = df.select(randstr(F.lit(5), 
F.lit(0)).alias("x")).selectExpr("length(x)").collect()
-        self.assertEqual([Row(5)], result)
+        result = df.select(randstr(F.lit(5), 
F.lit(0)).alias("x")).selectExpr("length(x)")
+        assertDataFrameEqual([Row(5)], result)
         # The random seed is optional.
-        result = 
df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect()
-        self.assertEqual([Row(5)], result)
+        result = 
df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)")
+        assertDataFrameEqual([Row(5)], result)
 
         df = self.spark.createDataFrame([(0,)], ["a"])
-        result = (
-            df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x"))
-            .selectExpr("x > 5")
-            .collect()
-        )
-        self.assertEqual([Row(True)], result)
+        result = df.select(uniform(F.lit(10), F.lit(20), 
F.lit(0)).alias("x")).selectExpr("x > 5")
+        assertDataFrameEqual([Row(True)], result)
         # The random seed is optional.
-        result = df.select(uniform(F.lit(10), 
F.lit(20)).alias("x")).selectExpr("x > 5").collect()
-        self.assertEqual([Row(True)], result)
+        result = df.select(uniform(F.lit(10), 
F.lit(20)).alias("x")).selectExpr("x > 5")
+        assertDataFrameEqual([Row(True)], result)
 
     def test_string_validation(self):
         df = self.spark.createDataFrame([("abc",)], ["a"])
         # test is_valid_utf8
-        result_is_valid_utf8 = 
df.select(F.is_valid_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r=True)], result_is_valid_utf8)
+        result_is_valid_utf8 = df.select(F.is_valid_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r=True)], result_is_valid_utf8)
         # test make_valid_utf8
-        result_make_valid_utf8 = 
df.select(F.make_valid_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r="abc")], result_make_valid_utf8)
+        result_make_valid_utf8 = df.select(F.make_valid_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r="abc")], result_make_valid_utf8)
         # test validate_utf8
-        result_validate_utf8 = 
df.select(F.validate_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r="abc")], result_validate_utf8)
+        result_validate_utf8 = df.select(F.validate_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r="abc")], result_validate_utf8)
         # test try_validate_utf8
-        result_try_validate_utf8 = 
df.select(F.try_validate_utf8(df.a).alias("r")).collect()
-        self.assertEqual([Row(r="abc")], result_try_validate_utf8)
+        result_try_validate_utf8 = 
df.select(F.try_validate_utf8(df.a).alias("r"))
+        assertDataFrameEqual([Row(r="abc")], result_try_validate_utf8)
 
 
 class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to