This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9f325422459d [SPARK-53963][PYTHON][TESTS] Drop temporary functions in regular UDF tests
9f325422459d is described below
commit 9f325422459d9d0df01f3a487d49d75a3944ace2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 21 16:17:57 2025 +0800
[SPARK-53963][PYTHON][TESTS] Drop temporary functions in regular UDF tests
### What changes were proposed in this pull request?
Drop temporary functions in regular UDF tests:
1. introduce `temp_func` in the testing utils (see the sketch after this list);
2. apply `temp_func` in the regular UDF tests (there are too many places to change; pandas/arrow UDFs will be done separately);
3. rename some temp function names, e.g. `double` -> `double_int`, so that they can be dropped.
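For illustration, a minimal sketch of the resulting pattern (a hypothetical test method; `temp_func` is the new helper added to `SQLTestUtils` in `python/pyspark/testing/sqlutils.py`):

    def test_double_int(self):
        # temp_func registers nothing itself; it only guarantees
        # "DROP TEMPORARY FUNCTION IF EXISTS double_int" on exit,
        # even if the assertions below fail.
        with self.temp_func("double_int"):
            self.spark.udf.register("double_int", lambda x: x * 2, IntegerType())
            [row] = self.spark.sql("SELECT double_int(21)").collect()
            self.assertEqual(row[0], 42)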
### Why are the changes needed?
to avoid polluting the testing environments
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #52674 from zhengruifeng/with_temp_func_udf.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/tests/test_types.py | 7 +-
python/pyspark/sql/tests/test_udf.py | 434 ++++++++++++++------------
python/pyspark/sql/tests/test_udf_profiler.py | 13 +-
python/pyspark/sql/tests/test_unified_udf.py | 102 +++---
python/pyspark/testing/sqlutils.py | 14 +
python/pyspark/tests/test_memory_profiler.py | 13 +-
6 files changed, 308 insertions(+), 275 deletions(-)
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 319ff92dd362..be0ff14965e0 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1002,9 +1002,10 @@ class TypesTestsMixin:
if x > 0:
return PythonOnlyPoint(float(x), float(x))
- self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
- rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
- self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+ with self.temp_func("udf"):
+ self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+ rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+ self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
def test_infer_schema_with_udt(self):
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index d148d3a2c64a..b1fb42ad11ec 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -82,18 +82,20 @@ class BaseUDFTestsMixin(object):
self.assertEqual(res.agg({"plus_four": "sum"}).collect()[0][0], 85)
def test_udf(self):
- self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) +
y, IntegerType())
- [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], 5)
+ with self.temp_func("twoArgs"):
+ self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x)
+ y, IntegerType())
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
def test_udf_on_sql_context(self):
from pyspark import SQLContext
- # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
- sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
- sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
- [row] = sqlContext.sql("SELECT oneArg('test')").collect()
- self.assertEqual(row[0], 4)
+ with self.temp_func("oneArg"):
+ # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
+ sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
+ sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
+ [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+ self.assertEqual(row[0], 4)
def test_udf2(self):
with self.tempView("test"):
@@ -103,20 +105,22 @@ class BaseUDFTestsMixin(object):
self.assertEqual(4, res[0])
def test_udf3(self):
- two_args = self.spark.catalog.registerFunction(
- "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)
- )
- self.assertEqual(two_args.deterministic, True)
- [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], "5")
+ with self.temp_func("twoArgs"):
+ two_args = self.spark.catalog.registerFunction(
+ "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)
+ )
+ self.assertEqual(two_args.deterministic, True)
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], "5")
def test_udf_registration_return_type_none(self):
- two_args = self.spark.catalog.registerFunction(
- "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y,
"integer"), None
- )
- self.assertEqual(two_args.deterministic, True)
- [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], 5)
+ with self.temp_func("twoArgs"):
+ two_args = self.spark.catalog.registerFunction(
+ "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y,
"integer"), None
+ )
+ self.assertEqual(two_args.deterministic, True)
+ [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
def test_udf_registration_return_type_not_none(self):
with self.quiet():
@@ -149,21 +153,22 @@ class BaseUDFTestsMixin(object):
def test_nondeterministic_udf2(self):
import random
- random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
- self.assertEqual(random_udf.deterministic, False)
- random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
- self.assertEqual(random_udf1.deterministic, False)
- [row] = self.spark.sql("SELECT randInt()").collect()
- self.assertEqual(row[0], 6)
- [row] = self.spark.range(1).select(random_udf1()).collect()
- self.assertEqual(row[0], 6)
- [row] = self.spark.range(1).select(random_udf()).collect()
- self.assertEqual(row[0], 6)
- # render_doc() reproduces the help() exception without printing output
- pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
- pydoc.render_doc(random_udf)
- pydoc.render_doc(random_udf1)
- pydoc.render_doc(udf(lambda x: x).asNondeterministic)
+ with self.temp_func("randInt"):
+ random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
+ self.assertEqual(random_udf.deterministic, False)
+ random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
+ self.assertEqual(random_udf1.deterministic, False)
+ [row] = self.spark.sql("SELECT randInt()").collect()
+ self.assertEqual(row[0], 6)
+ [row] = self.spark.range(1).select(random_udf1()).collect()
+ self.assertEqual(row[0], 6)
+ [row] = self.spark.range(1).select(random_udf()).collect()
+ self.assertEqual(row[0], 6)
+ # render_doc() reproduces the help() exception without printing output
+ pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+ pydoc.render_doc(random_udf)
+ pydoc.render_doc(random_udf1)
+ pydoc.render_doc(udf(lambda x: x).asNondeterministic)
def test_nondeterministic_udf3(self):
# regression test for SPARK-23233
@@ -194,32 +199,36 @@ class BaseUDFTestsMixin(object):
df.agg(sum(udf_random_col())).collect()
def test_chained_udf(self):
- self.spark.catalog.registerFunction("double_int", lambda x: x + x,
IntegerType())
- try:
+ with self.temp_func("double_int"):
+ self.spark.catalog.registerFunction("double_int", lambda x: x + x,
IntegerType())
[row] = self.spark.sql("SELECT double_int(1)").collect()
self.assertEqual(row[0], 2)
[row] = self.spark.sql("SELECT
double_int(double_int(1))").collect()
self.assertEqual(row[0], 4)
[row] = self.spark.sql("SELECT double_int(double_int(1) +
1)").collect()
self.assertEqual(row[0], 6)
- finally:
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS double_int")
def test_single_udf_with_repeated_argument(self):
# regression test for SPARK-20685
- self.spark.catalog.registerFunction("add", lambda x, y: x + y,
IntegerType())
- row = self.spark.sql("SELECT add(1, 1)").first()
- self.assertEqual(tuple(row), (2,))
+ with self.temp_func("add_int"):
+ self.spark.catalog.registerFunction("add_int", lambda x, y: x + y,
IntegerType())
+ row = self.spark.sql("SELECT add_int(1, 1)").first()
+ self.assertEqual(tuple(row), (2,))
def test_multiple_udfs(self):
- self.spark.catalog.registerFunction("double", lambda x: x * 2,
IntegerType())
- [row] = self.spark.sql("SELECT double(1), double(2)").collect()
- self.assertEqual(tuple(row), (2, 4))
- [row] = self.spark.sql("SELECT double(double(1)), double(double(2) +
2)").collect()
- self.assertEqual(tuple(row), (4, 12))
- self.spark.catalog.registerFunction("add", lambda x, y: x + y,
IntegerType())
- [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2),
1)").collect()
- self.assertEqual(tuple(row), (6, 5))
+ with self.temp_func("double_int"):
+ self.spark.catalog.registerFunction("double_int", lambda x: x * 2,
IntegerType())
+ [row] = self.spark.sql("SELECT double_int(1),
double_int(2)").collect()
+ self.assertEqual(tuple(row), (2, 4))
+ [row] = self.spark.sql(
+ "SELECT double_int(double_int(1)), double_int(double_int(2) +
2)"
+ ).collect()
+ self.assertEqual(tuple(row), (4, 12))
+ self.spark.catalog.registerFunction("add_int", lambda x, y: x + y,
IntegerType())
+ [row] = self.spark.sql(
+ "SELECT double_int(add_int(1, 2)), add_int(double_int(2), 1)"
+ ).collect()
+ self.assertEqual(tuple(row), (6, 5))
def test_udf_in_filter_on_top_of_outer_join(self):
left = self.spark.createDataFrame([Row(a=1)])
@@ -303,12 +312,13 @@ class BaseUDFTestsMixin(object):
assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
def test_udf_without_arguments(self):
- self.spark.catalog.registerFunction("foo", lambda: "bar")
- [row] = self.spark.sql("SELECT foo()").collect()
- self.assertEqual(row[0], "bar")
+ with self.temp_func("foo"):
+ self.spark.catalog.registerFunction("foo", lambda: "bar")
+ [row] = self.spark.sql("SELECT foo()").collect()
+ self.assertEqual(row[0], "bar")
def test_udf_with_array_type(self):
- with self.tempView("test"):
+ with self.tempView("test"), self.temp_func("copylist", "maplen"):
self.spark.createDataFrame(
[
([0, 1, 2], {"key": [0, 1, 2, 3, 4]}),
@@ -324,13 +334,14 @@ class BaseUDFTestsMixin(object):
self.assertEqual(1, l2)
def test_broadcast_in_udf(self):
- bar = {"a": "aa", "b": "bb", "c": "abc"}
- foo = self.sc.broadcast(bar)
- self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if
x else "")
- [res] = self.spark.sql("SELECT MYUDF('c')").collect()
- self.assertEqual("abc", res[0])
- [res] = self.spark.sql("SELECT MYUDF('')").collect()
- self.assertEqual("", res[0])
+ with self.temp_func("MYUDF"):
+ bar = {"a": "aa", "b": "bb", "c": "abc"}
+ foo = self.sc.broadcast(bar)
+ self.spark.catalog.registerFunction("MYUDF", lambda x:
foo.value[x] if x else "")
+ [res] = self.spark.sql("SELECT MYUDF('c')").collect()
+ self.assertEqual("abc", res[0])
+ [res] = self.spark.sql("SELECT MYUDF('')").collect()
+ self.assertEqual("", res[0])
def test_udf_with_filter_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
@@ -515,75 +526,78 @@ class BaseUDFTestsMixin(object):
def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
- add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
- self.assertListEqual(
- df.selectExpr("add_three(id) AS plus_three").collect(),
- df.select(add_three("id").alias("plus_three")).collect(),
- )
- add_three_str = self.spark.udf.register("add_three_str", lambda x: x + 3)
- self.assertListEqual(
- df.selectExpr("add_three_str(id) AS plus_three").collect(),
- df.select(add_three_str("id").alias("plus_three")).collect(),
- )
+ with self.temp_func("add_three"):
+ add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+ self.assertListEqual(
+ df.selectExpr("add_three(id) AS plus_three").collect(),
+ df.select(add_three("id").alias("plus_three")).collect(),
+ )
+
+ with self.temp_func("add_three_str"):
+ add_three_str = self.spark.udf.register("add_three_str", lambda x: x + 3)
+ self.assertListEqual(
+ df.selectExpr("add_three_str(id) AS plus_three").collect(),
+ df.select(add_three_str("id").alias("plus_three")).collect(),
+ )
def test_udf_registration_returns_udf_on_sql_context(self):
from pyspark import SQLContext
df = self.spark.range(10)
- # This is to check if a 'SQLContext.udf' can call its alias.
- sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
- add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
+ with self.temp_func("add_four"):
+ # This is to check if a 'SQLContext.udf' can call its alias.
+ sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
+ add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
- self.assertListEqual(
- df.selectExpr("add_four(id) AS plus_four").collect(),
- df.select(add_four("id").alias("plus_four")).collect(),
- )
+ self.assertListEqual(
+ df.selectExpr("add_four(id) AS plus_four").collect(),
+ df.select(add_four("id").alias("plus_four")).collect(),
+ )
@unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore
def test_register_java_function(self):
- self.spark.udf.registerJavaFunction(
- "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()
- )
- [value] = self.spark.sql("SELECT javaStringLength('test')").first()
- self.assertEqual(value, 4)
+ with self.temp_func("javaStringLength", "javaStringLength2", "javaStringLength3"):
+ self.spark.udf.registerJavaFunction(
+ "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()
+ )
+ [value] = self.spark.sql("SELECT javaStringLength('test')").first()
+ self.assertEqual(value, 4)
- self.spark.udf.registerJavaFunction(
- "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength"
- )
- [value] = self.spark.sql("SELECT javaStringLength2('test')").first()
- self.assertEqual(value, 4)
+ self.spark.udf.registerJavaFunction(
+ "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength"
+ )
+ [value] = self.spark.sql("SELECT javaStringLength2('test')").first()
+ self.assertEqual(value, 4)
- self.spark.udf.registerJavaFunction(
- "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength",
"integer"
- )
- [value] = self.spark.sql("SELECT javaStringLength3('test')").first()
- self.assertEqual(value, 4)
+ self.spark.udf.registerJavaFunction(
+ "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer"
+ )
+ [value] = self.spark.sql("SELECT javaStringLength3('test')").first()
+ self.assertEqual(value, 4)
@unittest.skipIf(not test_compiled, test_not_compiled_message) # type: ignore
def test_register_java_udaf(self):
- self.spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
- df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
- df.createOrReplaceTempView("df")
- row = self.spark.sql(
- "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
- ).first()
- self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())
+ with self.temp_func("javaUDAF"):
+ self.spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+ df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
+ df.createOrReplaceTempView("df")
+ row = self.spark.sql(
+ "SELECT name, javaUDAF(id) as avg from df group by name order by name desc"
+ ).first()
+ self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())
def test_err_udf_registration(self):
- with self.quiet():
- self.check_err_udf_registration()
-
- def check_err_udf_registration(self):
- with self.assertRaises(PySparkTypeError) as pe:
- self.spark.udf.register("f", UserDefinedFunction("x",
StringType()), "int")
-
- self.check_error(
- exception=pe.exception,
- errorClass="NOT_CALLABLE",
- messageParameters={"arg_name": "func", "arg_type": "str"},
- )
+ with self.quiet(), self.temp_func("f"):
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.udf.register("f", UserDefinedFunction("x",
StringType()), "int")
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NOT_CALLABLE",
+ messageParameters={"arg_name": "func", "arg_type": "str"},
+ )
def test_non_existed_udf(self):
spark = self.spark
@@ -1069,107 +1083,111 @@ class BaseUDFTestsMixin(object):
def test_udf(a, b):
return a + 10 * b
- self.spark.udf.register("test_udf", test_udf)
-
- for i, df in enumerate(
- [
- self.spark.range(2).select(test_udf(col("id"), b=col("id") *
10)),
- self.spark.range(2).select(test_udf(a=col("id"), b=col("id") *
10)),
- self.spark.range(2).select(test_udf(b=col("id") * 10,
a=col("id"))),
- self.spark.sql("SELECT test_udf(id, b => id * 10) FROM
range(2)"),
- self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM
range(2)"),
- self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM
range(2)"),
- ]
- ):
- with self.subTest(query_no=i):
- assertDataFrameEqual(df, [Row(0), Row(101)])
+ with self.temp_func("test_udf"):
+ self.spark.udf.register("test_udf", test_udf)
+
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"), b=col("id")
* 10)),
+ self.spark.range(2).select(test_udf(a=col("id"),
b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10,
a=col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM
range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10)
FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id)
FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
def test_named_arguments_negative(self):
@udf("int")
def test_udf(a, b):
return a + b
- self.spark.udf.register("test_udf", test_udf)
+ with self.temp_func("test_udf"):
+ self.spark.udf.register("test_udf", test_udf)
- with self.assertRaisesRegex(
- AnalysisException,
- "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
- ):
- self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
- with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
- self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
- with self.assertRaisesRegex(
- PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
- ):
- self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+ with self.assertRaisesRegex(
+ PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+ ):
+ self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
- with self.assertRaisesRegex(
- PythonException, r"test_udf\(\) got multiple values for argument 'a'"
- ):
- self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+ with self.assertRaisesRegex(
+ PythonException, r"test_udf\(\) got multiple values for argument 'a'"
+ ):
+ self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
def test_kwargs(self):
@udf("int")
def test_udf(**kwargs):
return kwargs["a"] + 10 * kwargs["b"]
- self.spark.udf.register("test_udf", test_udf)
+ with self.temp_func("test_udf"):
+ self.spark.udf.register("test_udf", test_udf)
- for i, df in enumerate(
- [
- self.spark.range(2).select(test_udf(a=col("id"), b=col("id") *
10)),
- self.spark.range(2).select(test_udf(b=col("id") * 10,
a=col("id"))),
- self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM
range(2)"),
- self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM
range(2)"),
- ]
- ):
- with self.subTest(query_no=i):
- assertDataFrameEqual(df, [Row(0), Row(101)])
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(a=col("id"),
b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10,
a=col("id"))),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10)
FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id)
FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
- # negative
- with self.assertRaisesRegex(
- AnalysisException,
- "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
- ):
- self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+ # negative
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
- with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
- self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
def test_named_arguments_and_defaults(self):
@udf("int")
def test_udf(a, b=0):
return a + 10 * b
- self.spark.udf.register("test_udf", test_udf)
+ with self.temp_func("test_udf"):
+ self.spark.udf.register("test_udf", test_udf)
- # without "b"
- for i, df in enumerate(
- [
- self.spark.range(2).select(test_udf(col("id"))),
- self.spark.range(2).select(test_udf(a=col("id"))),
- self.spark.sql("SELECT test_udf(id) FROM range(2)"),
- self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
- ]
- ):
- with self.subTest(with_b=False, query_no=i):
- assertDataFrameEqual(df, [Row(0), Row(1)])
-
- # with "b"
- for i, df in enumerate(
- [
- self.spark.range(2).select(test_udf(col("id"), b=col("id") *
10)),
- self.spark.range(2).select(test_udf(a=col("id"), b=col("id") *
10)),
- self.spark.range(2).select(test_udf(b=col("id") * 10,
a=col("id"))),
- self.spark.sql("SELECT test_udf(id, b => id * 10) FROM
range(2)"),
- self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM
range(2)"),
- self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM
range(2)"),
- ]
- ):
- with self.subTest(with_b=True, query_no=i):
- assertDataFrameEqual(df, [Row(0), Row(101)])
+ # without "b"
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"))),
+ self.spark.range(2).select(test_udf(a=col("id"))),
+ self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=False, query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(1)])
+
+ # with "b"
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"), b=col("id")
* 10)),
+ self.spark.range(2).select(test_udf(a=col("id"),
b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10,
a=col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM
range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10)
FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id)
FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=True, query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
def test_num_arguments(self):
@udf("long")
@@ -1544,28 +1562,36 @@ class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
# various batch size and see whether the query runs successfully, and the output is
# consistent across different batch sizes.
def test_udf_with_various_batch_size(self):
- self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) +
y, IntegerType())
- for batch_size in [1, 33, 1000, 2000]:
- with
self.sql_conf({"spark.sql.execution.python.udf.maxRecordsPerBatch":
batch_size}):
- df = self.spark.range(1000).selectExpr("twoArgs('test', id) AS
ret").orderBy("ret")
- rets = [x["ret"] for x in df.collect()]
- self.assertEqual(rets, list(range(4, 1004)))
+ with self.temp_func("twoArgs"):
+ self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x)
+ y, IntegerType())
+ for batch_size in [1, 33, 1000, 2000]:
+ with self.sql_conf(
+ {"spark.sql.execution.python.udf.maxRecordsPerBatch":
batch_size}
+ ):
+ df = (
+ self.spark.range(1000)
+ .selectExpr("twoArgs('test', id) AS ret")
+ .orderBy("ret")
+ )
+ rets = [x["ret"] for x in df.collect()]
+ self.assertEqual(rets, list(range(4, 1004)))
# We cannot check whether the buffer size is effective or not. We just run the query with
# various buffer size and see whether the query runs successfully, and the output is
# consistent across different batch sizes.
def test_udf_with_various_buffer_size(self):
- self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) +
y, IntegerType())
- for batch_size in [1, 33, 10000]:
- with self.sql_conf({"spark.sql.execution.python.udf.buffer.size":
batch_size}):
- df = (
- self.spark.range(1000)
- .repartition(1)
- .selectExpr("twoArgs('test', id) AS ret")
- .orderBy("ret")
- )
- rets = [x["ret"] for x in df.collect()]
- self.assertEqual(rets, list(range(4, 1004)))
+ with self.temp_func("twoArgs"):
+ self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x)
+ y, IntegerType())
+ for batch_size in [1, 33, 10000]:
+ with
self.sql_conf({"spark.sql.execution.python.udf.buffer.size": batch_size}):
+ df = (
+ self.spark.range(1000)
+ .repartition(1)
+ .selectExpr("twoArgs('test', id) AS ret")
+ .orderBy("ret")
+ )
+ rets = [x["ret"] for x in df.collect()]
+ self.assertEqual(rets, list(range(4, 1004)))
class UDFInitializationTests(unittest.TestCase):
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index de35532285df..4e8f722c22cb 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -263,15 +263,16 @@ class UDFProfiler2TestsMixin:
def add1(x):
return x + 1
- self.spark.udf.register("add1", add1)
+ with self.temp_func("add1"):
+ self.spark.udf.register("add1", add1)
- with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
- self.spark.sql("SELECT id, add1(id) add1 FROM range(10)").collect()
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ self.spark.sql("SELECT id, add1(id) add1 FROM
range(10)").collect()
- self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
+ self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
- for id in self.profile_results:
- self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=10)
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id,
expected_line_count_prefix=10)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/sql/tests/test_unified_udf.py b/python/pyspark/sql/tests/test_unified_udf.py
index 3c105637c791..2d3446bd0b5b 100644
--- a/python/pyspark/sql/tests/test_unified_udf.py
+++ b/python/pyspark/sql/tests/test_unified_udf.py
@@ -53,11 +53,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pd_add1("id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1")
- self.spark.udf.register("pd_add1", pd_add1)
- result2 = self.spark.sql("SELECT pd_add1(id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pd_add1")
+ with self.temp_func("pd_add1"):
+ self.spark.udf.register("pd_add1", pd_add1)
+ result2 = self.spark.sql("SELECT pd_add1(id) AS res FROM range(0,
10)").collect()
+ self.assertEqual(result2, expected)
def test_scalar_pandas_udf_II(self):
import pandas as pd
@@ -76,11 +75,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pd_add("id", "id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add")
- self.spark.udf.register("pd_add", pd_add)
- result2 = self.spark.sql("SELECT pd_add(id, id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pd_add")
+ with self.temp_func("pd_add"):
+ self.spark.udf.register("pd_add", pd_add)
+ result2 = self.spark.sql("SELECT pd_add(id, id) AS res FROM
range(0, 10)").collect()
+ self.assertEqual(result2, expected)
def test_scalar_pandas_iter_udf(self):
import pandas as pd
@@ -99,11 +97,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pd_add1_iter("id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1_iter")
- self.spark.udf.register("pd_add1_iter", pd_add1_iter)
- result2 = self.spark.sql("SELECT pd_add1_iter(id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pd_add1_iter")
+ with self.temp_func("pd_add1_iter"):
+ self.spark.udf.register("pd_add1_iter", pd_add1_iter)
+ result2 = self.spark.sql("SELECT pd_add1_iter(id) AS res FROM
range(0, 10)").collect()
+ self.assertEqual(result2, expected)
def test_scalar_pandas_iter_udf_II(self):
import pandas as pd
@@ -123,11 +120,12 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pd_add_iter("id", "id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add_iter")
- self.spark.udf.register("pd_add_iter", pd_add_iter)
- result2 = self.spark.sql("SELECT pd_add_iter(id, id) AS res FROM
range(0, 10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pd_add_iter")
+ with self.temp_func("pd_add_iter"):
+ self.spark.udf.register("pd_add_iter", pd_add_iter)
+ result2 = self.spark.sql(
+ "SELECT pd_add_iter(id, id) AS res FROM range(0, 10)"
+ ).collect()
+ self.assertEqual(result2, expected)
def test_grouped_agg_pandas_udf(self):
import pandas as pd
@@ -145,11 +143,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pd_max("id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_max")
- self.spark.udf.register("pd_max", pd_max)
- result2 = self.spark.sql("SELECT pd_max(id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pd_max")
+ with self.temp_func("pd_max"):
+ self.spark.udf.register("pd_max", pd_max)
+ result2 = self.spark.sql("SELECT pd_max(id) AS res FROM range(0,
10)").collect()
+ self.assertEqual(result2, expected)
def test_window_agg_pandas_udf(self):
import pandas as pd
@@ -180,9 +177,8 @@ class UnifiedUDFTestsMixin:
result1 = df.withColumn("res", pd_win_max("v").over(w)).collect()
self.assertEqual(result1, expected)
- with self.tempView("pd_tbl"):
+ with self.tempView("pd_tbl"), self.temp_func("pd_win_max"):
df.createOrReplaceTempView("pd_tbl")
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_win_max")
self.spark.udf.register("pd_win_max", pd_win_max)
result2 = self.spark.sql(
@@ -195,7 +191,6 @@ class UnifiedUDFTestsMixin:
"""
).collect()
self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pd_win_max")
def test_scalar_arrow_udf(self):
import pyarrow as pa
@@ -213,11 +208,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pa_add1("id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1")
- self.spark.udf.register("pa_add1", pa_add1)
- result2 = self.spark.sql("SELECT pa_add1(id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pa_add1")
+ with self.temp_func("pa_add1"):
+ self.spark.udf.register("pa_add1", pa_add1)
+ result2 = self.spark.sql("SELECT pa_add1(id) AS res FROM range(0,
10)").collect()
+ self.assertEqual(result2, expected)
def test_scalar_arrow_udf_II(self):
import pyarrow as pa
@@ -236,11 +230,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pa_add("id", "id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add")
- self.spark.udf.register("pa_add", pa_add)
- result2 = self.spark.sql("SELECT pa_add(id, id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pa_add")
+ with self.temp_func("pa_add"):
+ self.spark.udf.register("pa_add", pa_add)
+ result2 = self.spark.sql("SELECT pa_add(id, id) AS res FROM
range(0, 10)").collect()
+ self.assertEqual(result2, expected)
def test_scalar_arrow_iter_udf(self):
import pyarrow as pa
@@ -259,11 +252,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pa_add1_iter("id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1_iter")
- self.spark.udf.register("pa_add1_iter", pa_add1_iter)
- result2 = self.spark.sql("SELECT pa_add1_iter(id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pa_add1_iter")
+ with self.temp_func("pa_add1_iter"):
+ self.spark.udf.register("pa_add1_iter", pa_add1_iter)
+ result2 = self.spark.sql("SELECT pa_add1_iter(id) AS res FROM
range(0, 10)").collect()
+ self.assertEqual(result2, expected)
def test_scalar_arrow_iter_udf_II(self):
import pyarrow as pa
@@ -283,11 +275,12 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pa_add_iter("id", "id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add_iter")
- self.spark.udf.register("pa_add_iter", pa_add_iter)
- result2 = self.spark.sql("SELECT pa_add_iter(id, id) AS res FROM
range(0, 10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pa_add_iter")
+ with self.temp_func("pa_add_iter"):
+ self.spark.udf.register("pa_add_iter", pa_add_iter)
+ result2 = self.spark.sql(
+ "SELECT pa_add_iter(id, id) AS res FROM range(0, 10)"
+ ).collect()
+ self.assertEqual(result2, expected)
def test_grouped_agg_arrow_udf(self):
import pyarrow as pa
@@ -305,11 +298,10 @@ class UnifiedUDFTestsMixin:
result1 = df.select(pa_max("id").alias("res")).collect()
self.assertEqual(result1, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_max")
- self.spark.udf.register("pa_max", pa_max)
- result2 = self.spark.sql("SELECT pa_max(id) AS res FROM range(0,
10)").collect()
- self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pa_max")
+ with self.temp_func("pa_max"):
+ self.spark.udf.register("pa_max", pa_max)
+ result2 = self.spark.sql("SELECT pa_max(id) AS res FROM range(0,
10)").collect()
+ self.assertEqual(result2, expected)
def test_window_agg_arrow_udf(self):
import pyarrow as pa
@@ -340,9 +332,8 @@ class UnifiedUDFTestsMixin:
result1 = df.withColumn("mean_v", pa_win_max("v").over(w)).collect()
self.assertEqual(result1, expected)
- with self.tempView("pa_tbl"):
+ with self.tempView("pa_tbl"), self.temp_func("pa_win_max"):
df.createOrReplaceTempView("pa_tbl")
- self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_win_max")
self.spark.udf.register("pa_win_max", pa_win_max)
result2 = self.spark.sql(
@@ -355,7 +346,6 @@ class UnifiedUDFTestsMixin:
"""
).collect()
self.assertEqual(result2, expected)
- self.spark.sql("DROP TEMPORARY FUNCTION pa_win_max")
def test_regular_python_udf(self):
import pandas as pd
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 98d04e7d5b1a..645bf1f2ea80 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -147,6 +147,20 @@ class SQLTestUtils:
for v in views:
self.spark.catalog.dropTempView(v)
+ @contextmanager
+ def temp_func(self, *functions):
+ """
+ A convenient context manager to test with some specific temporary functions.
+ This drops the temporary functions if they exist.
+ """
+ assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."
+
+ try:
+ yield
+ finally:
+ for f in functions:
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS %s" % f)
+
@contextmanager
def function(self, *functions):
"""
diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py
index 144442b5a48f..df9d63c5260f 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -302,15 +302,16 @@ class MemoryProfiler2TestsMixin:
def add1(x):
return x + 1
- self.spark.udf.register("add1", add1)
+ with self.temp_func("add1"):
+ self.spark.udf.register("add1", add1)
- with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
- self.spark.sql("SELECT id, add1(id) add1 FROM range(10)").collect()
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
+ self.spark.sql("SELECT id, add1(id) add1 FROM
range(10)").collect()
- self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
+ self.assertEqual(1, len(self.profile_results),
str(self.profile_results.keys()))
- for id in self.profile_results:
- self.assert_udf_memory_profile_present(udf_id=id)
+ for id in self.profile_results:
+ self.assert_udf_memory_profile_present(udf_id=id)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]