This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 59dd406ffab6 [SPARK-48516][PYTHON][CONNECT] Turn on Arrow optimization
for Python UDFs by default
59dd406ffab6 is described below
commit 59dd406ffab6f7df7f36fe7befe121822e68bf00
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Feb 10 11:39:46 2025 -0800
[SPARK-48516][PYTHON][CONNECT] Turn on Arrow optimization for Python UDFs
by default
### What changes were proposed in this pull request?
Turn on Arrow optimization for Python UDFs by default
### Why are the changes needed?
Arrow optimization for Python UDFs was introduced in Spark 3.4.0. See
[SPARK-40307](https://issues.apache.org/jira/browse/SPARK-40307) for more
context.
An Arrow-optimized Python UDF is approximately 1.6 times faster than the
original pickled Python UDF. More details can be found in [this blog
post](https://www.databricks.com/blog/arrow-optimized-python-udfs-apache-sparktm-35).
In version 4.0.0, we propose enabling the optimization by default. If
PyArrow is not installed, it will fall back to the original pickled Python UDF.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Existing tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49482 from xinrong-meng/arrow_on.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/docs/source/user_guide/sql/arrow_pandas.rst | 4 ++--
python/docs/source/user_guide/sql/type_conversions.rst | 2 +-
python/pyspark/ml/base.py | 3 ++-
python/pyspark/ml/classification.py | 7 +++++--
python/pyspark/ml/tuning.py | 6 +++++-
python/pyspark/sql/connect/udf.py | 10 ++++++++++
python/pyspark/sql/functions/builtin.py | 3 ++-
python/pyspark/sql/udf.py | 12 ++++++++++--
python/pyspark/sql/utils.py | 9 +++++++++
.../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +-
10 files changed, 47 insertions(+), 11 deletions(-)
diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst
b/python/docs/source/user_guide/sql/arrow_pandas.rst
index fde40140110f..b9e389f8fe7d 100644
--- a/python/docs/source/user_guide/sql/arrow_pandas.rst
+++ b/python/docs/source/user_guide/sql/arrow_pandas.rst
@@ -356,8 +356,8 @@ Arrow Python UDFs are user defined functions that are
executed row-by-row, utili
transfer and serialization. To define an Arrow Python UDF, you can use the
:meth:`udf` decorator or wrap the function
with the :meth:`udf` method, ensuring the ``useArrow`` parameter is set to
True. Additionally, you can enable Arrow
optimization for Python UDFs throughout the entire SparkSession by setting the
Spark configuration
-``spark.sql.execution.pythonUDF.arrow.enabled`` to true. It's important to
note that the Spark configuration takes
-effect only when ``useArrow`` is either not set or set to None.
+``spark.sql.execution.pythonUDF.arrow.enabled`` to true, which is the default.
It's important to note that the Spark
+configuration takes effect only when ``useArrow`` is either not set or set to
None.
The type hints for Arrow Python UDFs should be specified in the same way as
for default, pickled Python UDFs.
diff --git a/python/docs/source/user_guide/sql/type_conversions.rst
b/python/docs/source/user_guide/sql/type_conversions.rst
index 2f13701995ef..80f8aa83db7e 100644
--- a/python/docs/source/user_guide/sql/type_conversions.rst
+++ b/python/docs/source/user_guide/sql/type_conversions.rst
@@ -57,7 +57,7 @@ are listed below:
- Default
* - spark.sql.execution.pythonUDF.arrow.enabled
- Enable PyArrow in PySpark. See more `here <arrow_pandas.rst>`_.
- - False
+ - True
* - spark.sql.pyspark.inferNestedDictAsStruct.enabled
- When enabled, nested dictionaries are inferred as StructType.
Otherwise, they are inferred as MapType.
- False
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 224ef34fd5ed..0bdfa27fc702 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -328,7 +328,8 @@ class UnaryTransformer(HasInputCol, HasOutputCol,
Transformer):
def _transform(self, dataset: DataFrame) -> DataFrame:
self.transformSchema(dataset.schema)
- transformUDF = udf(self.createTransformFunc(), self.outputDataType())
+ # TODO(SPARK-48515): Use Arrow Python UDF
+ transformUDF = udf(self.createTransformFunc(), self.outputDataType(),
useArrow=False)
transformedDataset = dataset.withColumn(
self.getOutputCol(), transformUDF(dataset[self.getInputCol()])
)
diff --git a/python/pyspark/ml/classification.py
b/python/pyspark/ml/classification.py
index 3a6425d0bfcd..0cd4f60ed7ae 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -3835,12 +3835,13 @@ class OneVsRestModel(
)
def _transform(self, dataset: DataFrame) -> DataFrame:
+ # TODO(SPARK-48515): Use Arrow Python UDF
# determine the input columns: these need to be passed through
origCols = dataset.columns
# add an accumulator column to store predictions of all the models
accColName = "mbc$acc" + str(uuid.uuid4())
- initUDF = udf(lambda _: [], ArrayType(DoubleType()))
+ initUDF = udf(lambda _: [], ArrayType(DoubleType()), useArrow=False)
newDataset = dataset.withColumn(accColName,
initUDF(dataset[origCols[0]]))
# persist if underlying dataset is not persistent.
@@ -3860,6 +3861,7 @@ class OneVsRestModel(
updateUDF = udf(
lambda predictions, prediction: predictions +
[prediction.tolist()[1]],
ArrayType(DoubleType()),
+ useArrow=False,
)
transformedDataset =
model.transform(aggregatedDataset).select(*columns)
updatedDataset = transformedDataset.withColumn(
@@ -3884,7 +3886,7 @@ class OneVsRestModel(
predArray.append(x)
return Vectors.dense(predArray)
- rawPredictionUDF = udf(func, VectorUDT())
+ rawPredictionUDF = udf(func, VectorUDT(), useArrow=False)
aggregatedDataset = aggregatedDataset.withColumn(
self.getRawPredictionCol(),
rawPredictionUDF(aggregatedDataset[accColName])
)
@@ -3896,6 +3898,7 @@ class OneVsRestModel(
max(enumerate(predictions), key=operator.itemgetter(1))[0]
),
DoubleType(),
+ useArrow=False,
)
aggregatedDataset = aggregatedDataset.withColumn(
self.getPredictionCol(),
labelUDF(aggregatedDataset[accColName])
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index d3506bf1c6b0..c85f8438079f 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -906,8 +906,12 @@ class CrossValidator(
from pyspark.sql.connect.udf import UserDefinedFunction
else:
from pyspark.sql.functions import UserDefinedFunction # type:
ignore[assignment]
+ from pyspark.util import PythonEvalType
- checker_udf = UserDefinedFunction(checker, BooleanType())
+ # TODO(SPARK-48515): Use Arrow Python UDF
+ checker_udf = UserDefinedFunction(
+ checker, BooleanType(), evalType=PythonEvalType.SQL_BATCHED_UDF
+ )
for i in range(nFolds):
training = dataset.filter(checker_udf(dataset[foldCol]) &
(col(foldCol) != lit(i)))
validation = dataset.filter(
diff --git a/python/pyspark/sql/connect/udf.py
b/python/pyspark/sql/connect/udf.py
index ab3a2da48ba5..6045e441222d 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -41,6 +41,7 @@ from pyspark.sql.udf import (
UDFRegistration as PySparkUDFRegistration,
UserDefinedFunction as PySparkUserDefinedFunction,
)
+from pyspark.sql.utils import has_arrow
from pyspark.errors import PySparkTypeError, PySparkRuntimeError
if TYPE_CHECKING:
@@ -58,6 +59,7 @@ def _create_py_udf(
returnType: "DataTypeOrString",
useArrow: Optional[bool] = None,
) -> "UserDefinedFunctionLike":
+ is_arrow_enabled = False
if useArrow is None:
is_arrow_enabled = False
try:
@@ -78,6 +80,14 @@ def _create_py_udf(
eval_type: int = PythonEvalType.SQL_BATCHED_UDF
+ if is_arrow_enabled and not has_arrow:
+ is_arrow_enabled = False
+ warnings.warn(
+ "Arrow optimization failed to enable because PyArrow is not
installed. "
+ "Falling back to a non-Arrow-optimized UDF.",
+ RuntimeWarning,
+ )
+
if is_arrow_enabled:
try:
is_func_with_args = len(getfullargspec(f).args) > 0
diff --git a/python/pyspark/sql/functions/builtin.py
b/python/pyspark/sql/functions/builtin.py
index 2b6d8569fdf8..4c9f8cad34f1 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -26350,7 +26350,8 @@ def udf(
Defaults to :class:`StringType`.
useArrow : bool, optional
whether to use Arrow to optimize the (de)serialization. When it is
None, the
- Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes
effect.
+ Spark config "spark.sql.execution.pythonUDF.arrow.enabled" takes
effect,
+ which is "true" by default.
Examples
--------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index cf093bd93643..2fd75390f48d 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -34,7 +34,7 @@ from pyspark.sql.types import (
StructType,
_parse_datatype_string,
)
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, has_arrow
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version,
require_minimum_pyarrow_version
from pyspark.errors import PySparkTypeError, PySparkNotImplementedError,
PySparkRuntimeError
@@ -118,7 +118,7 @@ def _create_py_udf(
# Note: The values of 'SQL Type' are DDL formatted strings, which can be
used as `returnType`s.
# Note: The values inside the table are generated by `repr`. X' means it
throws an exception
# during the conversion.
-
+ is_arrow_enabled = False
if useArrow is None:
from pyspark.sql import SparkSession
@@ -131,6 +131,14 @@ def _create_py_udf(
else:
is_arrow_enabled = useArrow
+ if is_arrow_enabled and not has_arrow:
+ is_arrow_enabled = False
+ warnings.warn(
+ "Arrow optimization failed to enable because PyArrow is not
installed. "
+ "Falling back to a non-Arrow-optimized UDF.",
+ RuntimeWarning,
+ )
+
eval_type: int = PythonEvalType.SQL_BATCHED_UDF
if is_arrow_enabled:
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b0782d04cba3..63beda40dc52 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -63,6 +63,15 @@ if TYPE_CHECKING:
from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
+has_arrow: bool = False
+try:
+ import pyarrow # noqa: F401
+
+ has_arrow = True
+except ImportError:
+ pass
+
+
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 64142b2b61d1..84b8e1264be9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3512,7 +3512,7 @@ object SQLConf {
"can only be enabled when the given function takes at least one
argument.")
.version("3.4.0")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val PYTHON_UDF_ARROW_CONCURRENCY_LEVEL =
buildConf("spark.sql.execution.pythonUDF.arrow.concurrency.level")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]