This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9f325422459d [SPARK-53963][PYTHON][TESTS] Drop temporary functions in 
regular UDF tests
9f325422459d is described below

commit 9f325422459d9d0df01f3a487d49d75a3944ace2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 21 16:17:57 2025 +0800

    [SPARK-53963][PYTHON][TESTS] Drop temporary functions in regular UDF tests
    
    ### What changes were proposed in this pull request?
    Drop temporary functions in regular UDF tests:
    1. Introduce a `temp_func` context manager in the testing utils;
    2. Apply `temp_func` in the regular UDF tests (the pandas/arrow UDF tests
    have too many places to change, so they will be handled separately);
    3. Rename some temporary functions, e.g. `double` -> `double_int`, so that
    they can be dropped.
    
    ### Why are the changes needed?
    To avoid polluting the test environment with temporary functions left over
    from earlier tests.
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52674 from zhengruifeng/with_temp_func_udf.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/tests/test_types.py        |   7 +-
 python/pyspark/sql/tests/test_udf.py          | 434 ++++++++++++++------------
 python/pyspark/sql/tests/test_udf_profiler.py |  13 +-
 python/pyspark/sql/tests/test_unified_udf.py  | 102 +++---
 python/pyspark/testing/sqlutils.py            |  14 +
 python/pyspark/tests/test_memory_profiler.py  |  13 +-
 6 files changed, 308 insertions(+), 275 deletions(-)

diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 319ff92dd362..be0ff14965e0 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -1002,9 +1002,10 @@ class TypesTestsMixin:
             if x > 0:
                 return PythonOnlyPoint(float(x), float(x))
 
-        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
-        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
-        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+        with self.temp_func("udf"):
+            self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+            rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+            self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
 
     def test_infer_schema_with_udt(self):
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
diff --git a/python/pyspark/sql/tests/test_udf.py 
b/python/pyspark/sql/tests/test_udf.py
index d148d3a2c64a..b1fb42ad11ec 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -82,18 +82,20 @@ class BaseUDFTestsMixin(object):
         self.assertEqual(res.agg({"plus_four": "sum"}).collect()[0][0], 85)
 
     def test_udf(self):
-        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + 
y, IntegerType())
-        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], 5)
+        with self.temp_func("twoArgs"):
+            self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) 
+ y, IntegerType())
+            [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+            self.assertEqual(row[0], 5)
 
     def test_udf_on_sql_context(self):
         from pyspark import SQLContext
 
-        # This is to check if a deprecated 'SQLContext.registerFunction' can 
call its alias.
-        sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
-        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
-        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
-        self.assertEqual(row[0], 4)
+        with self.temp_func("oneArg"):
+            # This is to check if a deprecated 'SQLContext.registerFunction' 
can call its alias.
+            sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
+            sqlContext.registerFunction("oneArg", lambda x: len(x), 
IntegerType())
+            [row] = sqlContext.sql("SELECT oneArg('test')").collect()
+            self.assertEqual(row[0], 4)
 
     def test_udf2(self):
         with self.tempView("test"):
@@ -103,20 +105,22 @@ class BaseUDFTestsMixin(object):
             self.assertEqual(4, res[0])
 
     def test_udf3(self):
-        two_args = self.spark.catalog.registerFunction(
-            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)
-        )
-        self.assertEqual(two_args.deterministic, True)
-        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], "5")
+        with self.temp_func("twoArgs"):
+            two_args = self.spark.catalog.registerFunction(
+                "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y)
+            )
+            self.assertEqual(two_args.deterministic, True)
+            [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+            self.assertEqual(row[0], "5")
 
     def test_udf_registration_return_type_none(self):
-        two_args = self.spark.catalog.registerFunction(
-            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, 
"integer"), None
-        )
-        self.assertEqual(two_args.deterministic, True)
-        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], 5)
+        with self.temp_func("twoArgs"):
+            two_args = self.spark.catalog.registerFunction(
+                "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, 
"integer"), None
+            )
+            self.assertEqual(two_args.deterministic, True)
+            [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
+            self.assertEqual(row[0], 5)
 
     def test_udf_registration_return_type_not_none(self):
         with self.quiet():
@@ -149,21 +153,22 @@ class BaseUDFTestsMixin(object):
     def test_nondeterministic_udf2(self):
         import random
 
-        random_udf = udf(lambda: random.randint(6, 6), 
IntegerType()).asNondeterministic()
-        self.assertEqual(random_udf.deterministic, False)
-        random_udf1 = self.spark.catalog.registerFunction("randInt", 
random_udf)
-        self.assertEqual(random_udf1.deterministic, False)
-        [row] = self.spark.sql("SELECT randInt()").collect()
-        self.assertEqual(row[0], 6)
-        [row] = self.spark.range(1).select(random_udf1()).collect()
-        self.assertEqual(row[0], 6)
-        [row] = self.spark.range(1).select(random_udf()).collect()
-        self.assertEqual(row[0], 6)
-        # render_doc() reproduces the help() exception without printing output
-        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
-        pydoc.render_doc(random_udf)
-        pydoc.render_doc(random_udf1)
-        pydoc.render_doc(udf(lambda x: x).asNondeterministic)
+        with self.temp_func("randInt"):
+            random_udf = udf(lambda: random.randint(6, 6), 
IntegerType()).asNondeterministic()
+            self.assertEqual(random_udf.deterministic, False)
+            random_udf1 = self.spark.catalog.registerFunction("randInt", 
random_udf)
+            self.assertEqual(random_udf1.deterministic, False)
+            [row] = self.spark.sql("SELECT randInt()").collect()
+            self.assertEqual(row[0], 6)
+            [row] = self.spark.range(1).select(random_udf1()).collect()
+            self.assertEqual(row[0], 6)
+            [row] = self.spark.range(1).select(random_udf()).collect()
+            self.assertEqual(row[0], 6)
+            # render_doc() reproduces the help() exception without printing 
output
+            pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
+            pydoc.render_doc(random_udf)
+            pydoc.render_doc(random_udf1)
+            pydoc.render_doc(udf(lambda x: x).asNondeterministic)
 
     def test_nondeterministic_udf3(self):
         # regression test for SPARK-23233
@@ -194,32 +199,36 @@ class BaseUDFTestsMixin(object):
             df.agg(sum(udf_random_col())).collect()
 
     def test_chained_udf(self):
-        self.spark.catalog.registerFunction("double_int", lambda x: x + x, 
IntegerType())
-        try:
+        with self.temp_func("double_int"):
+            self.spark.catalog.registerFunction("double_int", lambda x: x + x, 
IntegerType())
             [row] = self.spark.sql("SELECT double_int(1)").collect()
             self.assertEqual(row[0], 2)
             [row] = self.spark.sql("SELECT 
double_int(double_int(1))").collect()
             self.assertEqual(row[0], 4)
             [row] = self.spark.sql("SELECT double_int(double_int(1) + 
1)").collect()
             self.assertEqual(row[0], 6)
-        finally:
-            self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS double_int")
 
     def test_single_udf_with_repeated_argument(self):
         # regression test for SPARK-20685
-        self.spark.catalog.registerFunction("add", lambda x, y: x + y, 
IntegerType())
-        row = self.spark.sql("SELECT add(1, 1)").first()
-        self.assertEqual(tuple(row), (2,))
+        with self.temp_func("add_int"):
+            self.spark.catalog.registerFunction("add_int", lambda x, y: x + y, 
IntegerType())
+            row = self.spark.sql("SELECT add_int(1, 1)").first()
+            self.assertEqual(tuple(row), (2,))
 
     def test_multiple_udfs(self):
-        self.spark.catalog.registerFunction("double", lambda x: x * 2, 
IntegerType())
-        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
-        self.assertEqual(tuple(row), (2, 4))
-        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 
2)").collect()
-        self.assertEqual(tuple(row), (4, 12))
-        self.spark.catalog.registerFunction("add", lambda x, y: x + y, 
IntegerType())
-        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 
1)").collect()
-        self.assertEqual(tuple(row), (6, 5))
+        with self.temp_func("double_int"):
+            self.spark.catalog.registerFunction("double_int", lambda x: x * 2, 
IntegerType())
+            [row] = self.spark.sql("SELECT double_int(1), 
double_int(2)").collect()
+            self.assertEqual(tuple(row), (2, 4))
+            [row] = self.spark.sql(
+                "SELECT double_int(double_int(1)), double_int(double_int(2) + 
2)"
+            ).collect()
+            self.assertEqual(tuple(row), (4, 12))
+            self.spark.catalog.registerFunction("add_int", lambda x, y: x + y, 
IntegerType())
+            [row] = self.spark.sql(
+                "SELECT double_int(add_int(1, 2)), add_int(double_int(2), 1)"
+            ).collect()
+            self.assertEqual(tuple(row), (6, 5))
 
     def test_udf_in_filter_on_top_of_outer_join(self):
         left = self.spark.createDataFrame([Row(a=1)])
@@ -303,12 +312,13 @@ class BaseUDFTestsMixin(object):
         assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
 
     def test_udf_without_arguments(self):
-        self.spark.catalog.registerFunction("foo", lambda: "bar")
-        [row] = self.spark.sql("SELECT foo()").collect()
-        self.assertEqual(row[0], "bar")
+        with self.temp_func("foo"):
+            self.spark.catalog.registerFunction("foo", lambda: "bar")
+            [row] = self.spark.sql("SELECT foo()").collect()
+            self.assertEqual(row[0], "bar")
 
     def test_udf_with_array_type(self):
-        with self.tempView("test"):
+        with self.tempView("test"), self.temp_func("copylist", "maplen"):
             self.spark.createDataFrame(
                 [
                     ([0, 1, 2], {"key": [0, 1, 2, 3, 4]}),
@@ -324,13 +334,14 @@ class BaseUDFTestsMixin(object):
             self.assertEqual(1, l2)
 
     def test_broadcast_in_udf(self):
-        bar = {"a": "aa", "b": "bb", "c": "abc"}
-        foo = self.sc.broadcast(bar)
-        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if 
x else "")
-        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
-        self.assertEqual("abc", res[0])
-        [res] = self.spark.sql("SELECT MYUDF('')").collect()
-        self.assertEqual("", res[0])
+        with self.temp_func("MYUDF"):
+            bar = {"a": "aa", "b": "bb", "c": "abc"}
+            foo = self.sc.broadcast(bar)
+            self.spark.catalog.registerFunction("MYUDF", lambda x: 
foo.value[x] if x else "")
+            [res] = self.spark.sql("SELECT MYUDF('c')").collect()
+            self.assertEqual("abc", res[0])
+            [res] = self.spark.sql("SELECT MYUDF('')").collect()
+            self.assertEqual("", res[0])
 
     def test_udf_with_filter_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
@@ -515,75 +526,78 @@ class BaseUDFTestsMixin(object):
 
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
-        add_three = self.spark.udf.register("add_three", lambda x: x + 3, 
IntegerType())
-        self.assertListEqual(
-            df.selectExpr("add_three(id) AS plus_three").collect(),
-            df.select(add_three("id").alias("plus_three")).collect(),
-        )
 
-        add_three_str = self.spark.udf.register("add_three_str", lambda x: x + 
3)
-        self.assertListEqual(
-            df.selectExpr("add_three_str(id) AS plus_three").collect(),
-            df.select(add_three_str("id").alias("plus_three")).collect(),
-        )
+        with self.temp_func("add_three"):
+            add_three = self.spark.udf.register("add_three", lambda x: x + 3, 
IntegerType())
+            self.assertListEqual(
+                df.selectExpr("add_three(id) AS plus_three").collect(),
+                df.select(add_three("id").alias("plus_three")).collect(),
+            )
+
+        with self.temp_func("add_three_str"):
+            add_three_str = self.spark.udf.register("add_three_str", lambda x: 
x + 3)
+            self.assertListEqual(
+                df.selectExpr("add_three_str(id) AS plus_three").collect(),
+                df.select(add_three_str("id").alias("plus_three")).collect(),
+            )
 
     def test_udf_registration_returns_udf_on_sql_context(self):
         from pyspark import SQLContext
 
         df = self.spark.range(10)
 
-        # This is to check if a 'SQLContext.udf' can call its alias.
-        sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
-        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, 
IntegerType())
+        with self.temp_func("add_four"):
+            # This is to check if a 'SQLContext.udf' can call its alias.
+            sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
+            add_four = sqlContext.udf.register("add_four", lambda x: x + 4, 
IntegerType())
 
-        self.assertListEqual(
-            df.selectExpr("add_four(id) AS plus_four").collect(),
-            df.select(add_four("id").alias("plus_four")).collect(),
-        )
+            self.assertListEqual(
+                df.selectExpr("add_four(id) AS plus_four").collect(),
+                df.select(add_four("id").alias("plus_four")).collect(),
+            )
 
     @unittest.skipIf(not test_compiled, test_not_compiled_message)  # type: 
ignore
     def test_register_java_function(self):
-        self.spark.udf.registerJavaFunction(
-            "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", 
IntegerType()
-        )
-        [value] = self.spark.sql("SELECT javaStringLength('test')").first()
-        self.assertEqual(value, 4)
+        with self.temp_func("javaStringLength", "javaStringLength2", 
"javaStringLength3"):
+            self.spark.udf.registerJavaFunction(
+                "javaStringLength", 
"test.org.apache.spark.sql.JavaStringLength", IntegerType()
+            )
+            [value] = self.spark.sql("SELECT javaStringLength('test')").first()
+            self.assertEqual(value, 4)
 
-        self.spark.udf.registerJavaFunction(
-            "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength"
-        )
-        [value] = self.spark.sql("SELECT javaStringLength2('test')").first()
-        self.assertEqual(value, 4)
+            self.spark.udf.registerJavaFunction(
+                "javaStringLength2", 
"test.org.apache.spark.sql.JavaStringLength"
+            )
+            [value] = self.spark.sql("SELECT 
javaStringLength2('test')").first()
+            self.assertEqual(value, 4)
 
-        self.spark.udf.registerJavaFunction(
-            "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", 
"integer"
-        )
-        [value] = self.spark.sql("SELECT javaStringLength3('test')").first()
-        self.assertEqual(value, 4)
+            self.spark.udf.registerJavaFunction(
+                "javaStringLength3", 
"test.org.apache.spark.sql.JavaStringLength", "integer"
+            )
+            [value] = self.spark.sql("SELECT 
javaStringLength3('test')").first()
+            self.assertEqual(value, 4)
 
     @unittest.skipIf(not test_compiled, test_not_compiled_message)  # type: 
ignore
     def test_register_java_udaf(self):
-        self.spark.udf.registerJavaUDAF("javaUDAF", 
"test.org.apache.spark.sql.MyDoubleAvg")
-        df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", 
"name"])
-        df.createOrReplaceTempView("df")
-        row = self.spark.sql(
-            "SELECT name, javaUDAF(id) as avg from df group by name order by 
name desc"
-        ).first()
-        self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())
+        with self.temp_func("javaUDAF"):
+            self.spark.udf.registerJavaUDAF("javaUDAF", 
"test.org.apache.spark.sql.MyDoubleAvg")
+            df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], 
["id", "name"])
+            df.createOrReplaceTempView("df")
+            row = self.spark.sql(
+                "SELECT name, javaUDAF(id) as avg from df group by name order 
by name desc"
+            ).first()
+            self.assertEqual(row.asDict(), Row(name="b", avg=102.0).asDict())
 
     def test_err_udf_registration(self):
-        with self.quiet():
-            self.check_err_udf_registration()
-
-    def check_err_udf_registration(self):
-        with self.assertRaises(PySparkTypeError) as pe:
-            self.spark.udf.register("f", UserDefinedFunction("x", 
StringType()), "int")
-
-        self.check_error(
-            exception=pe.exception,
-            errorClass="NOT_CALLABLE",
-            messageParameters={"arg_name": "func", "arg_type": "str"},
-        )
+        with self.quiet(), self.temp_func("f"):
+            with self.assertRaises(PySparkTypeError) as pe:
+                self.spark.udf.register("f", UserDefinedFunction("x", 
StringType()), "int")
+
+            self.check_error(
+                exception=pe.exception,
+                errorClass="NOT_CALLABLE",
+                messageParameters={"arg_name": "func", "arg_type": "str"},
+            )
 
     def test_non_existed_udf(self):
         spark = self.spark
@@ -1069,107 +1083,111 @@ class BaseUDFTestsMixin(object):
         def test_udf(a, b):
             return a + 10 * b
 
-        self.spark.udf.register("test_udf", test_udf)
-
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 
10)),
-                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
-                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
-                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
-            ]
-        ):
-            with self.subTest(query_no=i):
-                assertDataFrameEqual(df, [Row(0), Row(101)])
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(col("id"), b=col("id") 
* 10)),
+                    self.spark.range(2).select(test_udf(a=col("id"), 
b=col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                    self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) 
FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) 
FROM range(2)"),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(df, [Row(0), Row(101)])
 
     def test_named_arguments_negative(self):
         @udf("int")
         def test_udf(a, b):
             return a + b
 
-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
 
-        with self.assertRaisesRegex(
-            AnalysisException,
-            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
+            with self.assertRaisesRegex(
+                AnalysisException,
+                
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
 
-        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
 
-        with self.assertRaisesRegex(
-            PythonException, r"test_udf\(\) got an unexpected keyword argument 
'c'"
-        ):
-            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+            with self.assertRaisesRegex(
+                PythonException, r"test_udf\(\) got an unexpected keyword 
argument 'c'"
+            ):
+                self.spark.sql("SELECT test_udf(c => 'x') FROM 
range(2)").show()
 
-        with self.assertRaisesRegex(
-            PythonException, r"test_udf\(\) got multiple values for argument 
'a'"
-        ):
-            self.spark.sql("SELECT test_udf(id, a => id * 10) FROM 
range(2)").show()
+            with self.assertRaisesRegex(
+                PythonException, r"test_udf\(\) got multiple values for 
argument 'a'"
+            ):
+                self.spark.sql("SELECT test_udf(id, a => id * 10) FROM 
range(2)").show()
 
     def test_kwargs(self):
         @udf("int")
         def test_udf(**kwargs):
             return kwargs["a"] + 10 * kwargs["b"]
 
-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
 
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
-                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
-            ]
-        ):
-            with self.subTest(query_no=i):
-                assertDataFrameEqual(df, [Row(0), Row(101)])
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(a=col("id"), 
b=col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) 
FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) 
FROM range(2)"),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(df, [Row(0), Row(101)])
 
-        # negative
-        with self.assertRaisesRegex(
-            AnalysisException,
-            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
+            # negative
+            with self.assertRaisesRegex(
+                AnalysisException,
+                
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
 
-        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
 
     def test_named_arguments_and_defaults(self):
         @udf("int")
         def test_udf(a, b=0):
             return a + 10 * b
 
-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
 
-        # without "b"
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(col("id"))),
-                self.spark.range(2).select(test_udf(a=col("id"))),
-                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
-            ]
-        ):
-            with self.subTest(with_b=False, query_no=i):
-                assertDataFrameEqual(df, [Row(0), Row(1)])
-
-        # with "b"
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 
10)),
-                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
-                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
-                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
-            ]
-        ):
-            with self.subTest(with_b=True, query_no=i):
-                assertDataFrameEqual(df, [Row(0), Row(101)])
+            # without "b"
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(col("id"))),
+                    self.spark.range(2).select(test_udf(a=col("id"))),
+                    self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+                ]
+            ):
+                with self.subTest(with_b=False, query_no=i):
+                    assertDataFrameEqual(df, [Row(0), Row(1)])
+
+            # with "b"
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(col("id"), b=col("id") 
* 10)),
+                    self.spark.range(2).select(test_udf(a=col("id"), 
b=col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                    self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) 
FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) 
FROM range(2)"),
+                ]
+            ):
+                with self.subTest(with_b=True, query_no=i):
+                    assertDataFrameEqual(df, [Row(0), Row(101)])
 
     def test_num_arguments(self):
         @udf("long")
@@ -1544,28 +1562,36 @@ class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     # various batch size and see whether the query runs successfully, and the 
output is
     # consistent across different batch sizes.
     def test_udf_with_various_batch_size(self):
-        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + 
y, IntegerType())
-        for batch_size in [1, 33, 1000, 2000]:
-            with 
self.sql_conf({"spark.sql.execution.python.udf.maxRecordsPerBatch": 
batch_size}):
-                df = self.spark.range(1000).selectExpr("twoArgs('test', id) AS 
ret").orderBy("ret")
-                rets = [x["ret"] for x in df.collect()]
-                self.assertEqual(rets, list(range(4, 1004)))
+        with self.temp_func("twoArgs"):
+            self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) 
+ y, IntegerType())
+            for batch_size in [1, 33, 1000, 2000]:
+                with self.sql_conf(
+                    {"spark.sql.execution.python.udf.maxRecordsPerBatch": 
batch_size}
+                ):
+                    df = (
+                        self.spark.range(1000)
+                        .selectExpr("twoArgs('test', id) AS ret")
+                        .orderBy("ret")
+                    )
+                    rets = [x["ret"] for x in df.collect()]
+                    self.assertEqual(rets, list(range(4, 1004)))
 
     # We cannot check whether the buffer size is effective or not. We just run 
the query with
     # various buffer size and see whether the query runs successfully, and the 
output is
     # consistent across different batch sizes.
     def test_udf_with_various_buffer_size(self):
-        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + 
y, IntegerType())
-        for batch_size in [1, 33, 10000]:
-            with self.sql_conf({"spark.sql.execution.python.udf.buffer.size": 
batch_size}):
-                df = (
-                    self.spark.range(1000)
-                    .repartition(1)
-                    .selectExpr("twoArgs('test', id) AS ret")
-                    .orderBy("ret")
-                )
-                rets = [x["ret"] for x in df.collect()]
-                self.assertEqual(rets, list(range(4, 1004)))
+        with self.temp_func("twoArgs"):
+            self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) 
+ y, IntegerType())
+            for batch_size in [1, 33, 10000]:
+                with 
self.sql_conf({"spark.sql.execution.python.udf.buffer.size": batch_size}):
+                    df = (
+                        self.spark.range(1000)
+                        .repartition(1)
+                        .selectExpr("twoArgs('test', id) AS ret")
+                        .orderBy("ret")
+                    )
+                    rets = [x["ret"] for x in df.collect()]
+                    self.assertEqual(rets, list(range(4, 1004)))
 
 
 class UDFInitializationTests(unittest.TestCase):
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py 
b/python/pyspark/sql/tests/test_udf_profiler.py
index de35532285df..4e8f722c22cb 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -263,15 +263,16 @@ class UDFProfiler2TestsMixin:
         def add1(x):
             return x + 1
 
-        self.spark.udf.register("add1", add1)
+        with self.temp_func("add1"):
+            self.spark.udf.register("add1", add1)
 
-        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
-            self.spark.sql("SELECT id, add1(id) add1 FROM range(10)").collect()
+            with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+                self.spark.sql("SELECT id, add1(id) add1 FROM 
range(10)").collect()
 
-        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+            self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
-        for id in self.profile_results:
-            self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=10)
+            for id in self.profile_results:
+                self.assert_udf_profile_present(udf_id=id, 
expected_line_count_prefix=10)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/sql/tests/test_unified_udf.py 
b/python/pyspark/sql/tests/test_unified_udf.py
index 3c105637c791..2d3446bd0b5b 100644
--- a/python/pyspark/sql/tests/test_unified_udf.py
+++ b/python/pyspark/sql/tests/test_unified_udf.py
@@ -53,11 +53,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pd_add1("id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1")
-        self.spark.udf.register("pd_add1", pd_add1)
-        result2 = self.spark.sql("SELECT pd_add1(id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pd_add1")
+        with self.temp_func("pd_add1"):
+            self.spark.udf.register("pd_add1", pd_add1)
+            result2 = self.spark.sql("SELECT pd_add1(id) AS res FROM range(0, 
10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_scalar_pandas_udf_II(self):
         import pandas as pd
@@ -76,11 +75,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pd_add("id", "id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add")
-        self.spark.udf.register("pd_add", pd_add)
-        result2 = self.spark.sql("SELECT pd_add(id, id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pd_add")
+        with self.temp_func("pd_add"):
+            self.spark.udf.register("pd_add", pd_add)
+            result2 = self.spark.sql("SELECT pd_add(id, id) AS res FROM 
range(0, 10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_scalar_pandas_iter_udf(self):
         import pandas as pd
@@ -99,11 +97,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pd_add1_iter("id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1_iter")
-        self.spark.udf.register("pd_add1_iter", pd_add1_iter)
-        result2 = self.spark.sql("SELECT pd_add1_iter(id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pd_add1_iter")
+        with self.temp_func("pd_add1_iter"):
+            self.spark.udf.register("pd_add1_iter", pd_add1_iter)
+            result2 = self.spark.sql("SELECT pd_add1_iter(id) AS res FROM 
range(0, 10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_scalar_pandas_iter_udf_II(self):
         import pandas as pd
@@ -123,11 +120,12 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pd_add_iter("id", "id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add_iter")
-        self.spark.udf.register("pd_add_iter", pd_add_iter)
-        result2 = self.spark.sql("SELECT pd_add_iter(id, id) AS res FROM 
range(0, 10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pd_add_iter")
+        with self.temp_func("pd_add_iter"):
+            self.spark.udf.register("pd_add_iter", pd_add_iter)
+            result2 = self.spark.sql(
+                "SELECT pd_add_iter(id, id) AS res FROM range(0, 10)"
+            ).collect()
+            self.assertEqual(result2, expected)
 
     def test_grouped_agg_pandas_udf(self):
         import pandas as pd
@@ -145,11 +143,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pd_max("id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_max")
-        self.spark.udf.register("pd_max", pd_max)
-        result2 = self.spark.sql("SELECT pd_max(id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pd_max")
+        with self.temp_func("pd_max"):
+            self.spark.udf.register("pd_max", pd_max)
+            result2 = self.spark.sql("SELECT pd_max(id) AS res FROM range(0, 
10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_window_agg_pandas_udf(self):
         import pandas as pd
@@ -180,9 +177,8 @@ class UnifiedUDFTestsMixin:
         result1 = df.withColumn("res", pd_win_max("v").over(w)).collect()
         self.assertEqual(result1, expected)
 
-        with self.tempView("pd_tbl"):
+        with self.tempView("pd_tbl"), self.temp_func("pd_win_max"):
             df.createOrReplaceTempView("pd_tbl")
-            self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_win_max")
             self.spark.udf.register("pd_win_max", pd_win_max)
 
             result2 = self.spark.sql(
@@ -195,7 +191,6 @@ class UnifiedUDFTestsMixin:
                 """
             ).collect()
             self.assertEqual(result2, expected)
-            self.spark.sql("DROP TEMPORARY FUNCTION pd_win_max")
 
     def test_scalar_arrow_udf(self):
         import pyarrow as pa
@@ -213,11 +208,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pa_add1("id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1")
-        self.spark.udf.register("pa_add1", pa_add1)
-        result2 = self.spark.sql("SELECT pa_add1(id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pa_add1")
+        with self.temp_func("pa_add1"):
+            self.spark.udf.register("pa_add1", pa_add1)
+            result2 = self.spark.sql("SELECT pa_add1(id) AS res FROM range(0, 
10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_scalar_arrow_udf_II(self):
         import pyarrow as pa
@@ -236,11 +230,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pa_add("id", "id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add")
-        self.spark.udf.register("pa_add", pa_add)
-        result2 = self.spark.sql("SELECT pa_add(id, id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pa_add")
+        with self.temp_func("pa_add"):
+            self.spark.udf.register("pa_add", pa_add)
+            result2 = self.spark.sql("SELECT pa_add(id, id) AS res FROM 
range(0, 10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_scalar_arrow_iter_udf(self):
         import pyarrow as pa
@@ -259,11 +252,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pa_add1_iter("id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1_iter")
-        self.spark.udf.register("pa_add1_iter", pa_add1_iter)
-        result2 = self.spark.sql("SELECT pa_add1_iter(id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pa_add1_iter")
+        with self.temp_func("pa_add1_iter"):
+            self.spark.udf.register("pa_add1_iter", pa_add1_iter)
+            result2 = self.spark.sql("SELECT pa_add1_iter(id) AS res FROM 
range(0, 10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_scalar_arrow_iter_udf_II(self):
         import pyarrow as pa
@@ -283,11 +275,12 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pa_add_iter("id", "id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add_iter")
-        self.spark.udf.register("pa_add_iter", pa_add_iter)
-        result2 = self.spark.sql("SELECT pa_add_iter(id, id) AS res FROM 
range(0, 10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pa_add_iter")
+        with self.temp_func("pa_add_iter"):
+            self.spark.udf.register("pa_add_iter", pa_add_iter)
+            result2 = self.spark.sql(
+                "SELECT pa_add_iter(id, id) AS res FROM range(0, 10)"
+            ).collect()
+            self.assertEqual(result2, expected)
 
     def test_grouped_agg_arrow_udf(self):
         import pyarrow as pa
@@ -305,11 +298,10 @@ class UnifiedUDFTestsMixin:
         result1 = df.select(pa_max("id").alias("res")).collect()
         self.assertEqual(result1, expected)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_max")
-        self.spark.udf.register("pa_max", pa_max)
-        result2 = self.spark.sql("SELECT pa_max(id) AS res FROM range(0, 
10)").collect()
-        self.assertEqual(result2, expected)
-        self.spark.sql("DROP TEMPORARY FUNCTION pa_max")
+        with self.temp_func("pa_max"):
+            self.spark.udf.register("pa_max", pa_max)
+            result2 = self.spark.sql("SELECT pa_max(id) AS res FROM range(0, 
10)").collect()
+            self.assertEqual(result2, expected)
 
     def test_window_agg_arrow_udf(self):
         import pyarrow as pa
@@ -340,9 +332,8 @@ class UnifiedUDFTestsMixin:
         result1 = df.withColumn("mean_v", pa_win_max("v").over(w)).collect()
         self.assertEqual(result1, expected)
 
-        with self.tempView("pa_tbl"):
+        with self.tempView("pa_tbl"), self.temp_func("pa_win_max"):
             df.createOrReplaceTempView("pa_tbl")
-            self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_win_max")
             self.spark.udf.register("pa_win_max", pa_win_max)
 
             result2 = self.spark.sql(
@@ -355,7 +346,6 @@ class UnifiedUDFTestsMixin:
                 """
             ).collect()
             self.assertEqual(result2, expected)
-            self.spark.sql("DROP TEMPORARY FUNCTION pa_win_max")
 
     def test_regular_python_udf(self):
         import pandas as pd
diff --git a/python/pyspark/testing/sqlutils.py 
b/python/pyspark/testing/sqlutils.py
index 98d04e7d5b1a..645bf1f2ea80 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -147,6 +147,20 @@ class SQLTestUtils:
             for v in views:
                 self.spark.catalog.dropTempView(v)
 
+    @contextmanager
+    def temp_func(self, *functions):
+        """
+        A convenient context manager to test with some specific temporary 
functions.
+        On exit, this drops the given temporary functions if they exist.
+        """
+        assert hasattr(self, "spark"), "it should have 'spark' attribute, 
having a spark session."
+
+        try:
+            yield
+        finally:
+            for f in functions:
+                self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS %s" % f)
+
     @contextmanager
     def function(self, *functions):
         """
diff --git a/python/pyspark/tests/test_memory_profiler.py 
b/python/pyspark/tests/test_memory_profiler.py
index 144442b5a48f..df9d63c5260f 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -302,15 +302,16 @@ class MemoryProfiler2TestsMixin:
         def add1(x):
             return x + 1
 
-        self.spark.udf.register("add1", add1)
+        with self.temp_func("add1"):
+            self.spark.udf.register("add1", add1)
 
-        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
-            self.spark.sql("SELECT id, add1(id) add1 FROM range(10)").collect()
+            with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
+                self.spark.sql("SELECT id, add1(id) add1 FROM 
range(10)").collect()
 
-        self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
+            self.assertEqual(1, len(self.profile_results), 
str(self.profile_results.keys()))
 
-        for id in self.profile_results:
-            self.assert_udf_memory_profile_present(udf_id=id)
+            for id in self.profile_results:
+                self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to