This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dc73342db941 [SPARK-50436][PYTHON][TESTS] Use assertDataFrameEqual in
pyspark.sql.tests.test_udf
dc73342db941 is described below
commit dc73342db941a7a202acacc2a7e90ff245192712
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Dec 2 08:42:08 2024 +0900
[SPARK-50436][PYTHON][TESTS] Use assertDataFrameEqual in
pyspark.sql.tests.test_udf
### What changes were proposed in this pull request?
Use `assertDataFrameEqual` in pyspark.sql.tests.test_udf
### Why are the changes needed?
`assertDataFrameEqual` is explicitly built to handle DataFrame-specific
comparisons, including schema.
So we propose to replace `assertEqual` with `assertDataFrameEqual`.
Part of https://issues.apache.org/jira/browse/SPARK-50435.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49001 from xinrong-meng/impr_test_udf.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_udf.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/tests/test_udf.py
b/python/pyspark/sql/tests/test_udf.py
index 78aa2546128a..819391389237 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -220,7 +220,7 @@ class BaseUDFTestsMixin(object):
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on="a", how="left_outer")
df = df.withColumn("b", udf(lambda x: "x")(df.a))
- self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b="x")])
+ assertDataFrameEqual(df.filter('b = "x"'), [Row(a=1, b="x")])
def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
@@ -228,7 +228,7 @@ class BaseUDFTestsMixin(object):
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
df = left.crossJoin(right).filter(f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ assertDataFrameEqual(df, [Row(a=1, b=1)])
def test_udf_in_join_condition(self):
# regression test for SPARK-25314
@@ -243,7 +243,7 @@ class BaseUDFTestsMixin(object):
df.collect()
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
df = left.join(right, f("a", "b"))
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ assertDataFrameEqual(df, [Row(a=1, b=1)])
def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
@@ -256,7 +256,7 @@ class BaseUDFTestsMixin(object):
# The Python UDF only refer to attributes from one side, so it's
evaluable.
df = left.join(right, f("a") == col("b").cast("string"),
how="left_outer")
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
- self.assertEqual(df.collect(), [Row(a=1, b=1)])
+ assertDataFrameEqual(df, [Row(a=1, b=1)])
def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
@@ -266,7 +266,7 @@ class BaseUDFTestsMixin(object):
f = udf(lambda a, b: a == b, BooleanType())
df = left.join(right, [f("a", "b"), left.a1 == right.b1])
# do not need spark.sql.crossJoin.enabled=true for udf is not the only
join condition.
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+ assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
@@ -294,7 +294,7 @@ class BaseUDFTestsMixin(object):
f = udf(lambda a: a, IntegerType())
df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
- self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+ assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
@@ -331,7 +331,7 @@ class BaseUDFTestsMixin(object):
my_filter = udf(lambda a: a < 2, BooleanType())
sel = df.select(col("key"),
col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
- self.assertEqual(sel.collect(), [Row(key=1, value="1")])
+ assertDataFrameEqual(sel, [Row(key=1, value="1")])
def test_udf_with_variant_input(self):
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as
string)) v")
@@ -461,7 +461,7 @@ class BaseUDFTestsMixin(object):
my_filter = udf(lambda a: a == 1, BooleanType())
sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
- self.assertEqual(sel.collect(), [Row(key=1)])
+ assertDataFrameEqual(sel, [Row(key=1)])
my_copy = udf(lambda x: x, IntegerType())
my_add = udf(lambda a, b: int(a + b), IntegerType())
@@ -471,7 +471,7 @@ class BaseUDFTestsMixin(object):
.agg(sum(my_strlen(col("value"))).alias("s"))
.select(my_add(col("k"), col("s")).alias("t"))
)
- self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+ assertDataFrameEqual(sel, [Row(t=4), Row(t=3)])
def test_udf_in_generate(self):
from pyspark.sql.functions import explode
@@ -505,7 +505,7 @@ class BaseUDFTestsMixin(object):
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
- self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+ assertDataFrameEqual(res, [Row(id=0, copy=0)])
def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
@@ -838,12 +838,12 @@ class BaseUDFTestsMixin(object):
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn("c", c1)
expected = df.withColumn("c", lit(2))
- self.assertEqual(expected.collect(), result.collect())
+ assertDataFrameEqual(expected, result)
for df in [filesource_df, datasource_df, datasource_v2_df]:
result = df.withColumn("c", c2)
expected = df.withColumn("c", col("i") + 1)
- self.assertEqual(expected.collect(), result.collect())
+ assertDataFrameEqual(expected, result)
for df in [filesource_df, datasource_df, datasource_v2_df]:
for f in [f1, f2]:
@@ -902,7 +902,7 @@ class BaseUDFTestsMixin(object):
result = self.spark.sql(
"select i from values(0L) as data(i) where i in (select id
from v)"
)
- self.assertEqual(result.collect(), [Row(i=0)])
+ assertDataFrameEqual(result, [Row(i=0)])
def test_udf_globals_not_overwritten(self):
@udf("string")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]