This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 748de5f3710e [SPARK-53969][PYTHON][TESTS] Drop temporary functions in Arrow UDF tests
748de5f3710e is described below

commit 748de5f3710e9401d4ae854ca17fe6060c080caa
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 21 10:38:30 2025 -0700

    [SPARK-53969][PYTHON][TESTS] Drop temporary functions in Arrow UDF tests
    
    ### What changes were proposed in this pull request?
    Drop temporary functions in Arrow UDF tests
    
    ### Why are the changes needed?
    To avoid polluting the test environment with leftover temporary functions.
    
    ### Does this PR introduce _any_ user-facing change?
    No, test-only.
    
    ### How was this patch tested?
    CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #52682 from zhengruifeng/with_temp_func_arrow.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/tests/arrow/test_arrow_python_udf.py       |  53 ++--
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py  |  31 ++-
 .../sql/tests/arrow/test_arrow_udf_scalar.py       | 270 +++++++++++----------
 .../sql/tests/arrow/test_arrow_udf_window.py       |   6 +-
 4 files changed, 188 insertions(+), 172 deletions(-)
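
Note: the patch replaces inline "DROP TEMPORARY FUNCTION IF EXISTS ..." statements with a `temp_func` context manager so that cleanup also runs when a test body raises. The helper itself lives in the shared test utilities and is not part of this diff; the following is a minimal sketch of what it could look like, assuming a `contextmanager`-based implementation that reuses the same `DROP TEMPORARY FUNCTION` SQL the old tests ran inline:

    from contextlib import contextmanager

    @contextmanager
    def temp_func(self, *names):
        # Hypothetical sketch: guarantee the named temporary functions are
        # dropped on exit, even if the wrapped test body raises, so a failing
        # test cannot leak registrations into later tests.
        try:
            yield
        finally:
            for name in names:
                # Same cleanup statement the old tests ran inline before
                # registering, moved into a finally block.
                self.spark.sql(f"DROP TEMPORARY FUNCTION IF EXISTS {name}")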

diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 8ab02875cb3a..ba1fd3e4e0a9 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -128,20 +128,24 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
         df = self.spark.range(1).selectExpr(
             "array(1, 2, 3) as array",
         )
-        str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True))
 
-        # To verify that Arrow optimization is on
-        self.assertIn(
-            df.selectExpr("str_repr(array) AS str_id").first()[0],
-            ["[1, 2, 3]", "[np.int32(1), np.int32(2), np.int32(3)]"],
-            # The input is a NumPy array when the Arrow optimization is on
-        )
+        with self.temp_func("str_repr"):
+            str_repr_func = self.spark.udf.register(
+                "str_repr", udf(lambda x: str(x), useArrow=True)
+            )
 
-        # To verify that a UserDefinedFunction is returned
-        self.assertListEqual(
-            df.selectExpr("str_repr(array) AS str_id").collect(),
-            df.select(str_repr_func("array").alias("str_id")).collect(),
-        )
+            # To verify that Arrow optimization is on
+            self.assertIn(
+                df.selectExpr("str_repr(array) AS str_id").first()[0],
+                ["[1, 2, 3]", "[np.int32(1), np.int32(2), np.int32(3)]"],
+                # The input is a NumPy array when the Arrow optimization is on
+            )
+
+            # To verify that a UserDefinedFunction is returned
+            self.assertListEqual(
+                df.selectExpr("str_repr(array) AS str_id").collect(),
+                df.select(str_repr_func("array").alias("str_id")).collect(),
+            )
 
     def test_nested_array_input(self):
         df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) 
as nested_array")
@@ -275,22 +279,23 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
         def test_udf(a, b):
             return a + b
 
-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
 
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
 
-        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
 
-        with self.assertRaises(PythonException):
-            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+            with self.assertRaises(PythonException):
+                self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
 
-        with self.assertRaises(PythonException):
-            self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+            with self.assertRaises(PythonException):
+                self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
 
     def test_udf_with_udt(self):
         for fallback in [False, True]:
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index fae9650b2864..136a99e19411 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -488,12 +488,19 @@ class GroupedAggArrowUDFTestsMixin:
         )
 
         self.assertEqual(sum_arrow_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
-        group_agg_pandas_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
-        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
-        q = "SELECT sum_arrow_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
-        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
-        expected = [1, 5]
-        self.assertEqual(actual, expected)
+
+        with self.temp_func("sum_arrow_udf"):
+            group_agg_pandas_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
+            self.assertEqual(
+                group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
+            )
+            q = """
+                SELECT sum_arrow_udf(v1)
+                FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2
+                """
+            actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+            expected = [1, 5]
+            self.assertEqual(actual, expected)
 
     def test_grouped_with_empty_partition(self):
         import pyarrow as pa
@@ -516,10 +523,10 @@ class GroupedAggArrowUDFTestsMixin:
             return float(pa.compute.max(v).as_py())
 
         df = self.spark.range(0, 100)
-        self.spark.udf.register("max_udf", max_udf)
 
-        with self.tempView("table"):
+        with self.tempView("table"), self.temp_func("max_udf"):
             df.createTempView("table")
+            self.spark.udf.register("max_udf", max_udf)
 
             agg1 = df.agg(max_udf(df["id"]))
             agg2 = self.spark.sql("select max_udf(id) from table")
@@ -546,7 +553,7 @@ class GroupedAggArrowUDFTestsMixin:
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
 
@@ -575,7 +582,7 @@ class GroupedAggArrowUDFTestsMixin:
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
 
@@ -615,7 +622,7 @@ class GroupedAggArrowUDFTestsMixin:
 
             return np.average(kwargs["v"], weights=kwargs["w"])
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
 
@@ -660,7 +667,7 @@ class GroupedAggArrowUDFTestsMixin:
         def biased_sum(v, w=None):
             return pa.compute.sum(v).as_py() + (pa.compute.sum(w).as_py() if w is not None else 100)
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("biased_sum"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("biased_sum", biased_sum)
 
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index 75ab8e2ff15b..a682c6515ef6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -707,38 +707,38 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
         self.assertEqual(scalar_original_add.deterministic, True)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.udf.register("add1", scalar_original_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.udf.register("add1", scalar_original_add)
 
-        self.assertEqual(new_add.deterministic, True)
-        self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+            self.assertEqual(new_add.deterministic, True)
+            self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
 
-        df = self.spark.range(10).select(
-            F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
-        )
-        res1 = df.select(new_add(F.col("a"), F.col("b")))
-        res2 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res1.collect())
-        self.assertEqual(expected.collect(), res2.collect())
+            df = self.spark.range(10).select(
+                F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
+            )
+            res1 = df.select(new_add(F.col("a"), F.col("b")))
+            res2 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res1.collect())
+            self.assertEqual(expected.collect(), res2.collect())
 
         @arrow_udf(LongType())
         def scalar_iter_add(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
             for a, b in it:
                 yield pa.compute.add(a, b)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.udf.register("add1", scalar_iter_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.udf.register("add1", scalar_iter_add)
 
-        res3 = df.select(new_add(F.col("a"), F.col("b")))
-        res4 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM 
range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res3.collect())
-        self.assertEqual(expected.collect(), res4.collect())
+            res3 = df.select(new_add(F.col("a"), F.col("b")))
+            res4 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM 
range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res3.collect())
+            self.assertEqual(expected.collect(), res4.collect())
 
     def test_catalog_register_arrow_udf_basic(self):
         import pyarrow as pa
@@ -749,38 +749,38 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
         self.assertEqual(scalar_original_add.deterministic, True)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.catalog.registerFunction("add1", 
scalar_original_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.catalog.registerFunction("add1", 
scalar_original_add)
 
-        self.assertEqual(new_add.deterministic, True)
-        self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+            self.assertEqual(new_add.deterministic, True)
+            self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
 
-        df = self.spark.range(10).select(
-            F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
-        )
-        res1 = df.select(new_add(F.col("a"), F.col("b")))
-        res2 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res1.collect())
-        self.assertEqual(expected.collect(), res2.collect())
+            df = self.spark.range(10).select(
+                F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
+            )
+            res1 = df.select(new_add(F.col("a"), F.col("b")))
+            res2 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res1.collect())
+            self.assertEqual(expected.collect(), res2.collect())
 
         @arrow_udf(LongType())
         def scalar_iter_add(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
             for a, b in it:
                 yield pa.compute.add(a, b)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.catalog.registerFunction("add1", scalar_iter_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.catalog.registerFunction("add1", 
scalar_iter_add)
 
-        res3 = df.select(new_add(F.col("a"), F.col("b")))
-        res4 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM 
range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res3.collect())
-        self.assertEqual(expected.collect(), res4.collect())
+            res3 = df.select(new_add(F.col("a"), F.col("b")))
+            res4 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM 
range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res3.collect())
+            self.assertEqual(expected.collect(), res4.collect())
 
     def test_udf_register_nondeterministic_arrow_udf(self):
         import pyarrow as pa
@@ -791,13 +791,15 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(random_arrow_udf.deterministic, False)
         self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF")
-        nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", 
random_arrow_udf)
+        with self.temp_func("randomArrowUDF"):
+            nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", random_arrow_udf)
 
-        self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
-        self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
-        [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
-        self.assertEqual(row[0], 7)
+            self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
+            self.assertEqual(
+                nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF
+            )
+            [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
+            self.assertEqual(row[0], 7)
 
     def test_catalog_register_nondeterministic_arrow_udf(self):
         import pyarrow as pa
@@ -808,15 +810,17 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(random_arrow_udf.deterministic, False)
         self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF")
-        nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
-            "randomArrowUDF", random_arrow_udf
-        )
+        with self.temp_func("randomArrowUDF"):
+            nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
+                "randomArrowUDF", random_arrow_udf
+            )
 
-        self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
-        self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
-        [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
-        self.assertEqual(row[0], 7)
+            self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
+            self.assertEqual(
+                nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF
+            )
+            [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
+            self.assertEqual(row[0], 7)
 
     @unittest.skipIf(not have_numpy, numpy_requirement_message)
     def test_nondeterministic_arrow_udf(self):
@@ -981,22 +985,22 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, b):
             return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
-
-        expected = [Row(0), Row(101)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") 
* 10)),
-                self.spark.range(2).select(test_udf(a=F.col("id"), 
b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(b=F.col("id") * 10, 
a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
-            ]
-        ):
-            with self.subTest(query_no=i):
-                self.assertEqual(expected, df.collect())
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            expected = [Row(0), Row(101)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(F.col("id"), 
b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(a=F.col("id"), 
b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=F.col("id") * 10, 
a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) 
FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) 
FROM range(2)"),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    self.assertEqual(expected, df.collect())
 
     def test_arrow_udf_named_arguments_negative(self):
         import pyarrow as pa
@@ -1005,22 +1009,22 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, b):
             return pa.compute.add(a, b).cast(pa.int32())
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
 
-        with self.assertRaisesRegex(
-            AnalysisException,
-            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
 
-        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
 
-        with self.assertRaisesRegex(
-            PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
-        ):
-            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+            with self.assertRaisesRegex(
+                PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+            ):
+                self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
 
     def test_arrow_udf_named_arguments_and_defaults(self):
         import pyarrow as pa
@@ -1029,36 +1033,36 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, b=0):
             return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
-
-        # without "b"
-        expected = [Row(0), Row(1)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(F.col("id"))),
-                self.spark.range(2).select(test_udf(a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
-            ]
-        ):
-            with self.subTest(with_b=False, query_no=i):
-                self.assertEqual(expected, df.collect())
-
-        # with "b"
-        expected = [Row(0), Row(101)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") 
* 10)),
-                self.spark.range(2).select(test_udf(a=F.col("id"), 
b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(b=F.col("id") * 10, 
a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
-            ]
-        ):
-            with self.subTest(with_b=True, query_no=i):
-                self.assertEqual(expected, df.collect())
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            # without "b"
+            expected = [Row(0), Row(1)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(F.col("id"))),
+                    self.spark.range(2).select(test_udf(a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+                ]
+            ):
+                with self.subTest(with_b=False, query_no=i):
+                    self.assertEqual(expected, df.collect())
+
+            # with "b"
+            expected = [Row(0), Row(101)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(F.col("id"), 
b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(a=F.col("id"), 
b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=F.col("id") * 10, 
a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) 
FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) 
FROM range(2)"),
+                ]
+            ):
+                with self.subTest(with_b=True, query_no=i):
+                    self.assertEqual(expected, df.collect())
 
     def test_arrow_udf_kwargs(self):
         import pyarrow as pa
@@ -1067,20 +1071,20 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, **kwargs):
             return pa.compute.add(a, pa.compute.multiply(kwargs["b"], 
10)).cast(pa.int32())
 
-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
-
-        expected = [Row(0), Row(101)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(a=F.col("id"), 
b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(b=F.col("id") * 10, 
a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
-            ]
-        ):
-            with self.subTest(query_no=i):
-                self.assertEqual(expected, df.collect())
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            expected = [Row(0), Row(101)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(a=F.col("id"), 
b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=F.col("id") * 10, 
a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) 
FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) 
FROM range(2)"),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    self.assertEqual(expected, df.collect())
 
     def test_arrow_iter_udf_single_column(self):
         import pyarrow as pa
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index b3ed4c020ca3..d67b99475bf8 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -404,7 +404,7 @@ class WindowArrowUDFTestsMixin:
                         windowed.collect(), df.withColumn("wm", 
sf.mean(df.v).over(w)).collect()
                     )
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
 
@@ -435,7 +435,7 @@ class WindowArrowUDFTestsMixin:
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
 
@@ -505,7 +505,7 @@ class WindowArrowUDFTestsMixin:
                         windowed.collect(), df.withColumn("wm", 
sf.mean(df.v).over(w)).collect()
                     )
 
-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
