This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 748de5f3710e [SPARK-53969][PYTHON][TESTS] Drop temporary functions in Arrow UDF tests
748de5f3710e is described below
commit 748de5f3710e9401d4ae854ca17fe6060c080caa
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Oct 21 10:38:30 2025 -0700
[SPARK-53969][PYTHON][TESTS] Drop temporary functions in Arrow UDF tests
### What changes were proposed in this pull request?
Drop temporary functions in Arrow UDF tests
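For context: each `spark.udf.register(...)` / `spark.catalog.registerFunction(...)` call in these tests is now wrapped in a `temp_func` context manager, so the temporary function is dropped even when an assertion fails. A minimal sketch of what such a helper could look like (hypothetical; the actual helper is defined in PySpark's shared test utilities, not shown in this diff, and may differ):

```python
from contextlib import contextmanager

class TempFuncMixin:
    """Hypothetical sketch of the helper the tests call as self.temp_func(...).
    Assumes a SparkSession is available at self.spark."""

    @contextmanager
    def temp_func(self, *names):
        try:
            yield
        finally:
            for name in names:
                # Drop the session-scoped registration so later tests
                # do not see functions leaked by this one.
                self.spark.sql(f"DROP TEMPORARY FUNCTION IF EXISTS {name}")
```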
### Why are the changes needed?
to avoid polluting the test environment with leftover temporary functions
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
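Note: several of the updated tests stack the existing `tempView` helper with the new `temp_func` in a single `with` statement (e.g. `with self.tempView("v"), self.temp_func("weighted_mean"):`), so both cleanups run, in reverse order of entry, even if the test body raises. A tiny standalone illustration of that stacking, using hypothetical stand-in managers:

```python
from contextlib import contextmanager

@contextmanager
def demo(label):
    # Stand-in for tempView / temp_func: report enter and exit.
    print(f"enter {label}")
    try:
        yield
    finally:
        print(f"exit {label}")

# Managers are entered left to right and exited in reverse order,
# so cleanup happens even when the body raises.
with demo("tempView"), demo("temp_func"):
    print("test body")
```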
Closes #52682 from zhengruifeng/with_temp_func_arrow.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/tests/arrow/test_arrow_python_udf.py      |  53 ++--
 .../sql/tests/arrow/test_arrow_udf_grouped_agg.py |  31 ++-
 .../sql/tests/arrow/test_arrow_udf_scalar.py      | 270 +++++++++++----------
 .../sql/tests/arrow/test_arrow_udf_window.py      |   6 +-
 4 files changed, 188 insertions(+), 172 deletions(-)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 8ab02875cb3a..ba1fd3e4e0a9 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -128,20 +128,24 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
         df = self.spark.range(1).selectExpr(
             "array(1, 2, 3) as array",
         )
-        str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True))
-        # To verify that Arrow optimization is on
-        self.assertIn(
-            df.selectExpr("str_repr(array) AS str_id").first()[0],
-            ["[1, 2, 3]", "[np.int32(1), np.int32(2), np.int32(3)]"],
-            # The input is a NumPy array when the Arrow optimization is on
-        )
+        with self.temp_func("str_repr"):
+            str_repr_func = self.spark.udf.register(
+                "str_repr", udf(lambda x: str(x), useArrow=True)
+            )

-        # To verify that a UserDefinedFunction is returned
-        self.assertListEqual(
-            df.selectExpr("str_repr(array) AS str_id").collect(),
-            df.select(str_repr_func("array").alias("str_id")).collect(),
-        )
+            # To verify that Arrow optimization is on
+            self.assertIn(
+                df.selectExpr("str_repr(array) AS str_id").first()[0],
+                ["[1, 2, 3]", "[np.int32(1), np.int32(2), np.int32(3)]"],
+                # The input is a NumPy array when the Arrow optimization is on
+            )
+
+            # To verify that a UserDefinedFunction is returned
+            self.assertListEqual(
+                df.selectExpr("str_repr(array) AS str_id").collect(),
+                df.select(str_repr_func("array").alias("str_id")).collect(),
+            )

     def test_nested_array_input(self):
         df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
@@ -275,22 +279,23 @@ class ArrowPythonUDFTestsMixin(BaseUDFTestsMixin):
         def test_udf(a, b):
             return a + b

-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)

-        with self.assertRaisesRegex(
-            AnalysisException,
-            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

-        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

-        with self.assertRaises(PythonException):
-            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+            with self.assertRaises(PythonException):
+                self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()

-        with self.assertRaises(PythonException):
-            self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+            with self.assertRaises(PythonException):
+                self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()

     def test_udf_with_udt(self):
         for fallback in [False, True]:
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index fae9650b2864..136a99e19411 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -488,12 +488,19 @@ class GroupedAggArrowUDFTestsMixin:
         )
         self.assertEqual(sum_arrow_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
-        group_agg_pandas_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
-        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
-        q = "SELECT sum_arrow_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
-        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
-        expected = [1, 5]
-        self.assertEqual(actual, expected)
+
+        with self.temp_func("sum_arrow_udf"):
+            group_agg_pandas_udf = self.spark.udf.register("sum_arrow_udf", sum_arrow_udf)
+            self.assertEqual(
+                group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF
+            )
+            q = """
+                SELECT sum_arrow_udf(v1)
+                FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2
+            """
+            actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
+            expected = [1, 5]
+            self.assertEqual(actual, expected)

     def test_grouped_with_empty_partition(self):
         import pyarrow as pa
@@ -516,10 +523,10 @@ class GroupedAggArrowUDFTestsMixin:
             return float(pa.compute.max(v).as_py())

         df = self.spark.range(0, 100)
-        self.spark.udf.register("max_udf", max_udf)
-        with self.tempView("table"):
+        with self.tempView("table"), self.temp_func("max_udf"):
             df.createTempView("table")
+            self.spark.udf.register("max_udf", max_udf)

             agg1 = df.agg(max_udf(df["id"]))
             agg2 = self.spark.sql("select max_udf(id) from table")
@@ -546,7 +553,7 @@ class GroupedAggArrowUDFTestsMixin:
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
@@ -575,7 +582,7 @@ class GroupedAggArrowUDFTestsMixin:
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
@@ -615,7 +622,7 @@ class GroupedAggArrowUDFTestsMixin:
             return np.average(kwargs["v"], weights=kwargs["w"])

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
@@ -660,7 +667,7 @@ class GroupedAggArrowUDFTestsMixin:
         def biased_sum(v, w=None):
             return pa.compute.sum(v).as_py() + (pa.compute.sum(w).as_py() if w is not None else 100)

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("biased_sum"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("biased_sum", biased_sum)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index 75ab8e2ff15b..a682c6515ef6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -707,38 +707,38 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
         self.assertEqual(scalar_original_add.deterministic, True)

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.udf.register("add1", scalar_original_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.udf.register("add1", scalar_original_add)

-        self.assertEqual(new_add.deterministic, True)
-        self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+            self.assertEqual(new_add.deterministic, True)
+            self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

-        df = self.spark.range(10).select(
-            F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
-        )
-        res1 = df.select(new_add(F.col("a"), F.col("b")))
-        res2 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res1.collect())
-        self.assertEqual(expected.collect(), res2.collect())
+            df = self.spark.range(10).select(
+                F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
+            )
+            res1 = df.select(new_add(F.col("a"), F.col("b")))
+            res2 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res1.collect())
+            self.assertEqual(expected.collect(), res2.collect())

         @arrow_udf(LongType())
         def scalar_iter_add(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
             for a, b in it:
                 yield pa.compute.add(a, b)

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.udf.register("add1", scalar_iter_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.udf.register("add1", scalar_iter_add)

-        res3 = df.select(new_add(F.col("a"), F.col("b")))
-        res4 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res3.collect())
-        self.assertEqual(expected.collect(), res4.collect())
+            res3 = df.select(new_add(F.col("a"), F.col("b")))
+            res4 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res3.collect())
+            self.assertEqual(expected.collect(), res4.collect())

     def test_catalog_register_arrow_udf_basic(self):
         import pyarrow as pa
@@ -749,38 +749,38 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
         self.assertEqual(scalar_original_add.deterministic, True)

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.catalog.registerFunction("add1", scalar_original_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.catalog.registerFunction("add1", scalar_original_add)

-        self.assertEqual(new_add.deterministic, True)
-        self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+            self.assertEqual(new_add.deterministic, True)
+            self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

-        df = self.spark.range(10).select(
-            F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
-        )
-        res1 = df.select(new_add(F.col("a"), F.col("b")))
-        res2 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res1.collect())
-        self.assertEqual(expected.collect(), res2.collect())
+            df = self.spark.range(10).select(
+                F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
+            )
+            res1 = df.select(new_add(F.col("a"), F.col("b")))
+            res2 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res1.collect())
+            self.assertEqual(expected.collect(), res2.collect())

         @arrow_udf(LongType())
         def scalar_iter_add(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
             for a, b in it:
                 yield pa.compute.add(a, b)

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
-        new_add = self.spark.catalog.registerFunction("add1", scalar_iter_add)
+        with self.temp_func("add1"):
+            new_add = self.spark.catalog.registerFunction("add1", scalar_iter_add)

-        res3 = df.select(new_add(F.col("a"), F.col("b")))
-        res4 = self.spark.sql(
-            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
-        )
-        expected = df.select(F.expr("a + b"))
-        self.assertEqual(expected.collect(), res3.collect())
-        self.assertEqual(expected.collect(), res4.collect())
+            res3 = df.select(new_add(F.col("a"), F.col("b")))
+            res4 = self.spark.sql(
+                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+            )
+            expected = df.select(F.expr("a + b"))
+            self.assertEqual(expected.collect(), res3.collect())
+            self.assertEqual(expected.collect(), res4.collect())

     def test_udf_register_nondeterministic_arrow_udf(self):
         import pyarrow as pa
@@ -791,13 +791,15 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(random_arrow_udf.deterministic, False)
         self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF")
-        nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", random_arrow_udf)
+        with self.temp_func("randomArrowUDF"):
+            nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", random_arrow_udf)

-        self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
-        self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
-        [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
-        self.assertEqual(row[0], 7)
+            self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
+            self.assertEqual(
+                nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF
+            )
+            [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
+            self.assertEqual(row[0], 7)

     def test_catalog_register_nondeterministic_arrow_udf(self):
         import pyarrow as pa
@@ -808,15 +810,17 @@ class ScalarArrowUDFTestsMixin:
         self.assertEqual(random_arrow_udf.deterministic, False)
         self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF")
-        nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
-            "randomArrowUDF", random_arrow_udf
-        )
+        with self.temp_func("randomArrowUDF"):
+            nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
+                "randomArrowUDF", random_arrow_udf
+            )

-        self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
-        self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
-        [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
-        self.assertEqual(row[0], 7)
+            self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
+            self.assertEqual(
+                nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF
+            )
+            [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
+            self.assertEqual(row[0], 7)

     @unittest.skipIf(not have_numpy, numpy_requirement_message)
     def test_nondeterministic_arrow_udf(self):
@@ -981,22 +985,22 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, b):
             return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
-
-        expected = [Row(0), Row(101)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
-            ]
-        ):
-            with self.subTest(query_no=i):
-                self.assertEqual(expected, df.collect())
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            expected = [Row(0), Row(101)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    self.assertEqual(expected, df.collect())

     def test_arrow_udf_named_arguments_negative(self):
         import pyarrow as pa
@@ -1005,22 +1009,22 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, b):
             return pa.compute.add(a, b).cast(pa.int32())

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)

-        with self.assertRaisesRegex(
-            AnalysisException,
-            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
-        ):
-            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

-        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
-            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

-        with self.assertRaisesRegex(
-            PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
-        ):
-            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+            with self.assertRaisesRegex(
+                PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+            ):
+                self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()

     def test_arrow_udf_named_arguments_and_defaults(self):
         import pyarrow as pa
@@ -1029,36 +1033,36 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, b=0):
             return pa.compute.add(a, pa.compute.multiply(b, 10)).cast(pa.int32())

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
-
-        # without "b"
-        expected = [Row(0), Row(1)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(F.col("id"))),
-                self.spark.range(2).select(test_udf(a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
-            ]
-        ):
-            with self.subTest(with_b=False, query_no=i):
-                self.assertEqual(expected, df.collect())
-
-        # with "b"
-        expected = [Row(0), Row(101)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
-            ]
-        ):
-            with self.subTest(with_b=True, query_no=i):
-                self.assertEqual(expected, df.collect())
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            # without "b"
+            expected = [Row(0), Row(1)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(F.col("id"))),
+                    self.spark.range(2).select(test_udf(a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+                ]
+            ):
+                with self.subTest(with_b=False, query_no=i):
+                    self.assertEqual(expected, df.collect())
+
+            # with "b"
+            expected = [Row(0), Row(101)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(F.col("id"), b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+                ]
+            ):
+                with self.subTest(with_b=True, query_no=i):
+                    self.assertEqual(expected, df.collect())

     def test_arrow_udf_kwargs(self):
         import pyarrow as pa
@@ -1067,20 +1071,20 @@ class ScalarArrowUDFTestsMixin:
         def test_udf(a, **kwargs):
             return pa.compute.add(a, pa.compute.multiply(kwargs["b"], 10)).cast(pa.int32())

-        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS test_udf")
-        self.spark.udf.register("test_udf", test_udf)
-
-        expected = [Row(0), Row(101)]
-        for i, df in enumerate(
-            [
-                self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
-                self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
-                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
-                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
-            ]
-        ):
-            with self.subTest(query_no=i):
-                self.assertEqual(expected, df.collect())
+        with self.temp_func("test_udf"):
+            self.spark.udf.register("test_udf", test_udf)
+
+            expected = [Row(0), Row(101)]
+            for i, df in enumerate(
+                [
+                    self.spark.range(2).select(test_udf(a=F.col("id"), b=F.col("id") * 10)),
+                    self.spark.range(2).select(test_udf(b=F.col("id") * 10, a=F.col("id"))),
+                    self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+                    self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    self.assertEqual(expected, df.collect())

     def test_arrow_iter_udf_single_column(self):
         import pyarrow as pa
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index b3ed4c020ca3..d67b99475bf8 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -404,7 +404,7 @@ class WindowArrowUDFTestsMixin:
             windowed.collect(), df.withColumn("wm", sf.mean(df.v).over(w)).collect()
         )

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
@@ -435,7 +435,7 @@ class WindowArrowUDFTestsMixin:
         df = self.data
         weighted_mean = self.arrow_agg_weighted_mean_udf

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
@@ -505,7 +505,7 @@ class WindowArrowUDFTestsMixin:
             windowed.collect(), df.withColumn("wm", sf.mean(df.v).over(w)).collect()
         )

-        with self.tempView("v"):
+        with self.tempView("v"), self.temp_func("weighted_mean"):
             df.createOrReplaceTempView("v")
             self.spark.udf.register("weighted_mean", weighted_mean)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]