This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 686d84453610 [SPARK-53592][PYTHON] Make `@udf` support vectorized UDF
686d84453610 is described below
commit 686d84453610e463df7df95395ce6ed36a6efacd
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Sep 20 20:19:31 2025 +0800
[SPARK-53592][PYTHON] Make `@udf` support vectorized UDF
### What changes were proposed in this pull request?
Make `@udf` support vectorized UDFs by inferring the eval type from the function's type hints.
### Why are the changes needed?
To promote the use of vectorized UDFs.
### Does this PR introduce _any_ user-facing change?
Yes. `udf` will now try to infer the eval type from the function's type hints.
For example,
```python
udf(returnType=LongType())
def pd_add1(ser: pd.Series) -> pd.Series:
assert isinstance(ser, pd.Series)
return ser + 1
```
The inferred eval type is `PythonEvalType.SQL_SCALAR_PANDAS_UDF`.
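The same hint-based inference covers the Arrow variants. As a minimal sketch based on the new `test_unified_udf.py` tests in this commit, a `pyarrow.Array` annotation maps to `PythonEvalType.SQL_SCALAR_ARROW_UDF`, while a function without vectorized hints keeps the existing behavior (a regular pickled UDF, or an Arrow-optimized one depending on `spark.sql.execution.pythonUDF.arrow.enabled`):
```python
import pyarrow as pa

from pyspark.sql.functions import udf
from pyspark.sql.types import LongType
from pyspark.util import PythonEvalType

@udf(returnType=LongType())
def pa_add1(arr: pa.Array) -> pa.Array:
    # pa.Array annotations are inferred as a scalar Arrow UDF
    assert isinstance(arr, pa.Array)
    return pa.compute.add(arr, 1)

assert pa_add1.evalType == PythonEvalType.SQL_SCALAR_ARROW_UDF
```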
### How was this patch tested?
Added unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52323 from zhengruifeng/unify_udf.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/sql/connect/udf.py | 22 +-
python/pyspark/sql/pandas/typehints.py | 26 +-
.../sql/tests/connect/test_parity_unified_udf.py | 40 ++
python/pyspark/sql/tests/test_unified_udf.py | 440 +++++++++++++++++++++
python/pyspark/sql/udf.py | 22 +-
6 files changed, 544 insertions(+), 8 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f26a70d68d03..a8bbf6c0eef6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -587,6 +587,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_udf",
"pyspark.sql.tests.test_udf_combinations",
"pyspark.sql.tests.test_udf_profiler",
+ "pyspark.sql.tests.test_unified_udf",
"pyspark.sql.tests.test_udtf",
"pyspark.sql.tests.test_tvf",
"pyspark.sql.tests.test_utils",
@@ -1107,6 +1108,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_udf",
"pyspark.sql.tests.connect.test_parity_udf_combinations",
"pyspark.sql.tests.connect.test_parity_udf_profiler",
+ "pyspark.sql.tests.connect.test_parity_unified_udf",
"pyspark.sql.tests.connect.test_parity_memory_profiler",
"pyspark.sql.tests.connect.test_parity_udtf",
"pyspark.sql.tests.connect.test_parity_tvf",
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 0f45690a9db3..fc5a4c79d8ad 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -77,10 +77,7 @@ def _create_py_udf(
else:
is_arrow_enabled = useArrow
- eval_type: int = PythonEvalType.SQL_BATCHED_UDF
-
if is_arrow_enabled:
- eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
try:
require_minimum_pandas_version()
require_minimum_pyarrow_version()
@@ -92,6 +89,25 @@ def _create_py_udf(
RuntimeWarning,
)
+ eval_type: Optional[int] = None
+ if useArrow is None:
+ # If the user doesn't explicitly set useArrow
+ from pyspark.sql.pandas.typehints import infer_eval_type_from_func
+
+ try:
+ # Try to infer the eval type from type hints
+ eval_type = infer_eval_type_from_func(f)
+ except Exception:
+            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
+
+ if eval_type is None:
+ if is_arrow_enabled:
+ # Arrow optimized Python UDF
+ eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
+ else:
+ # Fallback to Regular Python UDF
+ eval_type = PythonEvalType.SQL_BATCHED_UDF
+
return _create_udf(f, returnType, eval_type)
diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py
index f010489b9512..4252060f8b22 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -15,7 +15,8 @@
# limitations under the License.
#
from inspect import Signature
-from typing import Any, Callable, Dict, Optional, Union, TYPE_CHECKING
+from typing import Any, Callable, Dict, Optional, Union, TYPE_CHECKING, get_type_hints
+from inspect import getfullargspec, signature
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.errors import PySparkNotImplementedError, PySparkValueError
@@ -277,6 +278,29 @@ def infer_eval_type(
return eval_type
+def infer_eval_type_from_func( # type: ignore[no-untyped-def]
+ f,
+) -> Optional[
+ Union[
+ "PandasScalarUDFType",
+ "PandasScalarIterUDFType",
+ "PandasGroupedAggUDFType",
+ "ArrowScalarUDFType",
+ "ArrowScalarIterUDFType",
+ "ArrowGroupedAggUDFType",
+ ]
+]:
+ argspec = getfullargspec(f)
+ if len(argspec.annotations) > 0:
+ try:
+ type_hints = get_type_hints(f)
+ except NameError:
+ type_hints = {}
+ return infer_eval_type(signature(f), type_hints)
+ else:
+ return None
+
+
def check_tuple_annotation(
     annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
) -> bool:
diff --git a/python/pyspark/sql/tests/connect/test_parity_unified_udf.py b/python/pyspark/sql/tests/connect/test_parity_unified_udf.py
new file mode 100644
index 000000000000..8c076f173c95
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_unified_udf.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_unified_udf import UnifiedUDFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class UnifiedUDFParityTests(UnifiedUDFTestsMixin, ReusedConnectTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedConnectTestCase.setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_unified_udf import *  # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_unified_udf.py b/python/pyspark/sql/tests/test_unified_udf.py
new file mode 100644
index 000000000000..d74e404d7528
--- /dev/null
+++ b/python/pyspark/sql/tests/test_unified_udf.py
@@ -0,0 +1,440 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from typing import Iterator, Tuple
+
+from pyspark.sql import functions as sf
+from pyspark.sql.window import Window
+from pyspark.sql.functions import udf
+from pyspark.sql.types import LongType
+from pyspark.testing.utils import (
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.util import PythonEvalType
+
+
+@unittest.skipIf(
+ not have_pandas or not have_pyarrow,
+ pandas_requirement_message or pyarrow_requirement_message,
+)
+class UnifiedUDFTestsMixin:
+ def test_scalar_pandas_udf(self):
+ import pandas as pd
+
+ @udf(returnType=LongType())
+ def pd_add1(ser: pd.Series) -> pd.Series:
+ assert isinstance(ser, pd.Series)
+ return ser + 1
+
+        self.assertEqual(pd_add1.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + 1).alias("res")).collect()
+
+ result1 = df.select(pd_add1("id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1")
+ self.spark.udf.register("pd_add1", pd_add1)
+        result2 = self.spark.sql("SELECT pd_add1(id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pd_add1")
+
+ def test_scalar_pandas_udf_II(self):
+ import pandas as pd
+
+ @udf(returnType=LongType())
+ def pd_add(ser1: pd.Series, ser2: pd.Series) -> pd.Series:
+ assert isinstance(ser1, pd.Series)
+ assert isinstance(ser2, pd.Series)
+ return ser1 + ser2
+
+ self.assertEqual(pd_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + df.id).alias("res")).collect()
+
+ result1 = df.select(pd_add("id", "id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add")
+ self.spark.udf.register("pd_add", pd_add)
+        result2 = self.spark.sql("SELECT pd_add(id, id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pd_add")
+
+ def test_scalar_pandas_iter_udf(self):
+ import pandas as pd
+
+ @udf(returnType=LongType())
+ def pd_add1_iter(it: Iterator[pd.Series]) -> Iterator[pd.Series]:
+ for ser in it:
+ assert isinstance(ser, pd.Series)
+ yield ser + 1
+
+        self.assertEqual(pd_add1_iter.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + 1).alias("res")).collect()
+
+ result1 = df.select(pd_add1_iter("id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add1_iter")
+ self.spark.udf.register("pd_add1_iter", pd_add1_iter)
+        result2 = self.spark.sql("SELECT pd_add1_iter(id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pd_add1_iter")
+
+ def test_scalar_pandas_iter_udf_II(self):
+ import pandas as pd
+
+ @udf(returnType=LongType())
+        def pd_add_iter(it: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
+ for ser1, ser2 in it:
+ assert isinstance(ser1, pd.Series)
+ assert isinstance(ser2, pd.Series)
+ yield ser1 + ser2
+
+        self.assertEqual(pd_add_iter.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + df.id).alias("res")).collect()
+
+ result1 = df.select(pd_add_iter("id", "id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_add_iter")
+ self.spark.udf.register("pd_add_iter", pd_add_iter)
+        result2 = self.spark.sql("SELECT pd_add_iter(id, id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pd_add_iter")
+
+ def test_grouped_agg_pandas_udf(self):
+ import pandas as pd
+
+ @udf(returnType=LongType())
+ def pd_max(ser: pd.Series) -> int:
+ assert isinstance(ser, pd.Series)
+ return ser.max()
+
+        self.assertEqual(pd_max.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select(sf.max("id").alias("res")).collect()
+
+ result1 = df.select(pd_max("id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_max")
+ self.spark.udf.register("pd_max", pd_max)
+        result2 = self.spark.sql("SELECT pd_max(id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pd_max")
+
+ def test_window_agg_pandas_udf(self):
+ import pandas as pd
+
+ @udf(returnType=LongType())
+ def pd_win_max(ser: pd.Series) -> int:
+ assert isinstance(ser, pd.Series)
+ return ser.max()
+
+        self.assertEqual(pd_win_max.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
+
+ df = (
+ self.spark.range(10)
+            .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
+ .withColumn("v", sf.explode("vs"))
+ .drop("vs")
+ .withColumn("w", sf.lit(1.0))
+ )
+
+ w = (
+ Window.partitionBy("id")
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .orderBy("v")
+ )
+
+ expected = df.withColumn("res", sf.max("v").over(w)).collect()
+
+ result1 = df.withColumn("res", pd_win_max("v").over(w)).collect()
+ self.assertEqual(result1, expected)
+
+ with self.tempView("pd_tbl"):
+ df.createOrReplaceTempView("pd_tbl")
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pd_win_max")
+ self.spark.udf.register("pd_win_max", pd_win_max)
+
+ result2 = self.spark.sql(
+ """
+ SELECT *, pd_win_max(v) OVER (
+ PARTITION BY id
+ ORDER BY v
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS res FROM pd_tbl
+ """
+ ).collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pd_win_max")
+
+ def test_scalar_arrow_udf(self):
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+ def pa_add1(arr: pa.Array) -> pa.Array:
+ assert isinstance(arr, pa.Array)
+ return pa.compute.add(arr, 1)
+
+ self.assertEqual(pa_add1.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + 1).alias("res")).collect()
+
+ result1 = df.select(pa_add1("id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1")
+ self.spark.udf.register("pa_add1", pa_add1)
+        result2 = self.spark.sql("SELECT pa_add1(id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pa_add1")
+
+ def test_scalar_arrow_udf_II(self):
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+ def pa_add(arr1: pa.Array, arr2: pa.Array) -> pa.Array:
+ assert isinstance(arr1, pa.Array)
+ assert isinstance(arr2, pa.Array)
+ return pa.compute.add(arr1, arr2)
+
+ self.assertEqual(pa_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + df.id).alias("res")).collect()
+
+ result1 = df.select(pa_add("id", "id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add")
+ self.spark.udf.register("pa_add", pa_add)
+        result2 = self.spark.sql("SELECT pa_add(id, id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pa_add")
+
+ def test_scalar_arrow_iter_udf(self):
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+ def pa_add1_iter(it: Iterator[pa.Array]) -> Iterator[pa.Array]:
+ for arr in it:
+ assert isinstance(arr, pa.Array)
+ yield pa.compute.add(arr, 1)
+
+        self.assertEqual(pa_add1_iter.evalType, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + 1).alias("res")).collect()
+
+ result1 = df.select(pa_add1_iter("id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add1_iter")
+ self.spark.udf.register("pa_add1_iter", pa_add1_iter)
+        result2 = self.spark.sql("SELECT pa_add1_iter(id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pa_add1_iter")
+
+ def test_scalar_arrow_iter_udf_II(self):
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+        def pa_add_iter(it: Iterator[Tuple[pa.Array, pa.Array]]) -> Iterator[pa.Array]:
+ for arr1, arr2 in it:
+ assert isinstance(arr1, pa.Array)
+ assert isinstance(arr2, pa.Array)
+ yield pa.compute.add(arr1, arr2)
+
+        self.assertEqual(pa_add_iter.evalType, PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select((df.id + df.id).alias("res")).collect()
+
+ result1 = df.select(pa_add_iter("id", "id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_add_iter")
+ self.spark.udf.register("pa_add_iter", pa_add_iter)
+        result2 = self.spark.sql("SELECT pa_add_iter(id, id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pa_add_iter")
+
+ def test_grouped_agg_arrow_udf(self):
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+ def pa_max(arr: pa.Array) -> pa.Scalar:
+ assert isinstance(arr, pa.Array)
+ return pa.compute.max(arr)
+
+        self.assertEqual(pa_max.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
+
+ df = self.spark.range(0, 10)
+ expected = df.select(sf.max("id").alias("res")).collect()
+
+ result1 = df.select(pa_max("id").alias("res")).collect()
+ self.assertEqual(result1, expected)
+
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_max")
+ self.spark.udf.register("pa_max", pa_max)
+        result2 = self.spark.sql("SELECT pa_max(id) AS res FROM range(0, 10)").collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pa_max")
+
+ def test_window_agg_arrow_udf(self):
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+ def pa_win_max(arr: pa.Array) -> pa.Scalar:
+ assert isinstance(arr, pa.Array)
+ return pa.compute.max(arr)
+
+        self.assertEqual(pa_win_max.evalType, PythonEvalType.SQL_GROUPED_AGG_ARROW_UDF)
+
+ df = (
+ self.spark.range(10)
+            .withColumn("vs", sf.array([sf.lit(i * 1.0) + sf.col("id") for i in range(20, 30)]))
+ .withColumn("v", sf.explode("vs"))
+ .drop("vs")
+ .withColumn("w", sf.lit(1.0))
+ )
+
+ w = (
+ Window.partitionBy("id")
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .orderBy("v")
+ )
+
+ expected = df.withColumn("mean_v", sf.max("v").over(w)).collect()
+
+ result1 = df.withColumn("mean_v", pa_win_max("v").over(w)).collect()
+ self.assertEqual(result1, expected)
+
+ with self.tempView("pa_tbl"):
+ df.createOrReplaceTempView("pa_tbl")
+ self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS pa_win_max")
+ self.spark.udf.register("pa_win_max", pa_win_max)
+
+ result2 = self.spark.sql(
+ """
+ SELECT *, pa_win_max(v) OVER (
+ PARTITION BY id
+ ORDER BY v
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS res FROM pa_tbl
+ """
+ ).collect()
+ self.assertEqual(result2, expected)
+ self.spark.sql("DROP TEMPORARY FUNCTION pa_win_max")
+
+ def test_regular_python_udf(self):
+ import pandas as pd
+ import pyarrow as pa
+
+ @udf(returnType=LongType())
+ def f1(x):
+ return x + 1
+
+ @udf(returnType=LongType())
+ def f2(x: int) -> int:
+ return x + 1
+
+ # Cannot infer a vectorized UDF type
+ @udf(returnType=LongType())
+ def f3(x: int) -> pd.Series:
+ return x + 1
+
+ # Cannot infer a vectorized UDF type
+ @udf(returnType=LongType())
+ def f4(x: int) -> pa.Array:
+ return x + 1
+
+ # useArrow is explicitly set to false
+ @udf(returnType=LongType(), useArrow=False)
+ def f5(x: pd.Series) -> pd.Series:
+ return x + 1
+
+ # useArrow is explicitly set to false
+ @udf(returnType=LongType(), useArrow=False)
+ def f6(x: pa.Array) -> pa.Array:
+ return x + 1
+
+        expected = self.spark.range(10).select((sf.col("id") + 1).alias("res")).collect()
+ for f in [f1, f2, f3, f4, f5, f6]:
+ self.assertEqual(f.evalType, PythonEvalType.SQL_BATCHED_UDF)
+            result = self.spark.range(10).select(f("id").alias("res")).collect()
+ self.assertEqual(result, expected)
+
+ def test_arrow_optimized_python_udf(self):
+ import pandas as pd
+ import pyarrow as pa
+
+ @udf(returnType=LongType(), useArrow=True)
+ def f1(x):
+ return x + 1
+
+ @udf(returnType=LongType(), useArrow=True)
+ def f2(x: int) -> int:
+ return x + 1
+
+ # useArrow is explicitly set
+ @udf(returnType=LongType(), useArrow=True)
+ def f3(x: pd.Series) -> pd.Series:
+ return x + 1
+
+ # useArrow is explicitly set
+ @udf(returnType=LongType(), useArrow=True)
+ def f4(x: pa.Array) -> pa.Array:
+ return x + 1
+
+        expected = self.spark.range(10).select((sf.col("id") + 1).alias("res")).collect()
+ for f in [f1, f2, f3, f4]:
+ self.assertEqual(f.evalType, PythonEvalType.SQL_ARROW_BATCHED_UDF)
+            result = self.spark.range(10).select(f("id").alias("res")).collect()
+ self.assertEqual(result, expected)
+
+
+class UnifiedUDFTests(UnifiedUDFTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.test_unified_udf import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 6785e50cf351..4e37ec88846b 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -96,8 +96,8 @@ def _create_py_udf(
     # Arrow and Pickle have different type coercion rules, so a UDF might have a different result
     # with/without Arrow optimization. That's the main reason the Arrow optimization for Python
     # UDFs is disabled by default.
- is_arrow_enabled = False
+ is_arrow_enabled = False
if useArrow is None:
from pyspark.sql import SparkSession
@@ -122,10 +122,24 @@ def _create_py_udf(
RuntimeWarning,
)
- eval_type: int = PythonEvalType.SQL_BATCHED_UDF
+ eval_type: Optional[int] = None
+ if useArrow is None:
+ # If the user doesn't explicitly set useArrow
+ from pyspark.sql.pandas.typehints import infer_eval_type_from_func
- if is_arrow_enabled:
- eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
+ try:
+ # Try to infer the eval type from type hints
+ eval_type = infer_eval_type_from_func(f)
+ except Exception:
+            warnings.warn("Cannot infer the eval type from type hints. ", UserWarning)
+
+ if eval_type is None:
+ if is_arrow_enabled:
+ # Arrow optimized Python UDF
+ eval_type = PythonEvalType.SQL_ARROW_BATCHED_UDF
+ else:
+ # Fallback to Regular Python UDF
+ eval_type = PythonEvalType.SQL_BATCHED_UDF
return _create_udf(f, returnType, eval_type)