This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 77628fc1d01a [SPARK-55390][PYTHON] Consolidate SQL_SCALAR_ARROW_UDF wrapper, mapper, and serializer logic
77628fc1d01a is described below

commit 77628fc1d01a56f293185a25c52693e9b6b110a6
Author: Yicong Huang <[email protected]>
AuthorDate: Tue Mar 10 09:14:50 2026 +0800

    [SPARK-55390][PYTHON] Consolidate SQL_SCALAR_ARROW_UDF wrapper, mapper, and serializer logic
    
    ### What changes were proposed in this pull request?
    
    This PR consolidates the `SQL_SCALAR_ARROW_UDF` execution path by:
    
    1. Extracting `verify_scalar_result()` as a reusable helper to replace the inline `verify_result_type` and `verify_result_length` closures in `wrap_scalar_arrow_udf`
    2. Removing the dedicated `wrap_scalar_arrow_udf` wrapper and replacing it with the general `ArrowStreamGroupSerializer`-based path
    3. Adding `ArrowBatchTransformer.enforce_schema()` to handle schema enforcement (column reordering and type coercion) in one place
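A minimal pure-Python sketch of the two helpers described above, for illustration only. It assumes a toy batch model (a list of `(name, type_name, values)` tuples with string type labels) instead of `pa.RecordBatch`, models only an `int64` to `float64` cast, and raises built-in `TypeError`/`RuntimeError` where the real code raises `PySparkTypeError`/`PySparkRuntimeError`. The names `enforce_schema_sketch` and `verify_rows_sketch` are hypothetical. It follows the same structure as the diff below: an early return for empty inputs, a fast path when the schema already matches, positional access when the column order matches, and a row-count check via `len()`.

```python
from typing import Any, List, Tuple

# Toy stand-in for pa.RecordBatch (assumption): each column is
# (name, type_name, values), with type_name a plain string label.
Column = Tuple[str, str, List[Any]]
Schema = List[Tuple[str, str]]

# Hypothetical cast table; only int64 -> float64 widening is modelled.
CASTS = {("int64", "float64"): float}


def enforce_schema_sketch(batch: List[Column], schema: Schema) -> List[Column]:
    """Reorder columns to match `schema` and coerce mismatched types."""
    if not batch or not schema:
        return batch
    # Fast path: names and types already match, so return the input unchanged.
    if [(name, typ) for name, typ, _ in batch] == schema:
        return batch
    target_names = [name for name, _ in schema]
    # Same column order: look up by position instead of by name.
    use_index = [name for name, _, _ in batch] == target_names
    by_name = {name: (typ, values) for name, typ, values in batch}
    coerced = []
    for i, (name, target_type) in enumerate(schema):
        _, src_type, values = batch[i] if use_index else (name, *by_name[name])
        if src_type != target_type:
            cast = CASTS.get((src_type, target_type))
            if cast is None:
                # Stand-in for the "cast failed" error path.
                raise RuntimeError(f"expected {target_type}, got {src_type}")
            values = [None if v is None else cast(v) for v in values]
        coerced.append((name, target_type, values))
    return coerced


def verify_rows_sketch(result: Any, num_rows: int) -> Any:
    """Check that `result` is array-like and has exactly `num_rows` entries."""
    try:
        length = len(result)
    except TypeError:
        raise TypeError(
            f"expected array-like object, got {type(result).__name__}"
        ) from None
    if length != num_rows:
        raise RuntimeError(f"expected {num_rows} rows, got {length}")
    return result


batch = [("b", "int64", [1, 2]), ("a", "float64", [0.5, None])]
schema = [("a", "float64"), ("b", "float64")]
fixed = enforce_schema_sketch(batch, schema)
for _, _, values in fixed:
    verify_rows_sketch(values, 2)
print(fixed)  # [('a', 'float64', [0.5, None]), ('b', 'float64', [1.0, 2.0])]
```

Returning the input object unchanged on the fast path mirrors the `batch.schema.equals(...)` shortcut in the diff, which avoids rebuilding a batch that already conforms.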
    
    ### Why are the changes needed?
    
    The scalar Arrow UDF path had its own dedicated wrapper (`wrap_scalar_arrow_udf`), mapper, and serializer logic that duplicated patterns already available in the consolidated `ArrowStreamGroupSerializer` infrastructure. This refactoring reduces code duplication and makes the UDF execution paths more consistent.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests for scalar Arrow UDFs.
    
    ### Benchmark Results
    
    ASV microbenchmark comparison (`ScalarArrowUDFTimeBench`, `repeat=(3, 5, 5.0)`):
    
    | Scenario | UDF | Before (master) | After (PR) | Delta |
    |---|---|---|---|---|
    | sm_batch_few_col | identity | 71.2±0.5ms | 71.1±0.8ms | -0.1% |
    | sm_batch_few_col | sort | 186±1ms | 184±0.3ms | -1.1% |
    | sm_batch_few_col | nullcheck | 55.7±0.9ms | 56.4±0.4ms | +1.3% |
    | sm_batch_many_col | identity | 24.2±0.2ms | 23.6±0.06ms | -2.5% |
    | sm_batch_many_col | sort | 43.0±0.8ms | 42.1±0.6ms | -2.1% |
    | sm_batch_many_col | nullcheck | 20.6±0.3ms | 20.6±0.1ms | 0% |
    | lg_batch_few_col | identity | 465±1ms | 465±9ms | 0% |
    | lg_batch_few_col | sort | 824±1ms | 825±8ms | +0.1% |
    | lg_batch_few_col | nullcheck | 271±2ms | 278±3ms | +2.6% |
    | lg_batch_many_col | identity | 323±0.1ms | 321±0.9ms | -0.6% |
    | lg_batch_many_col | sort | 358±3ms | 362±4ms | +1.1% |
    | lg_batch_many_col | nullcheck | 326±2ms | 330±0.4ms | +1.2% |
    | pure_ints | identity | 112±2ms | 113±4ms | +0.9% |
    | pure_ints | sort | 179±0.6ms | 174±0.3ms | -2.8% |
    | pure_ints | nullcheck | 89.9±0.6ms | 91.8±0.4ms | +2.1% |
    | pure_floats | identity | 108±1ms | 108±0.2ms | 0% |
    | pure_floats | sort | 569±1ms | 568±1ms | -0.2% |
    | pure_floats | nullcheck | 88.6±0.6ms | 90.4±0.3ms | +2.0% |
    | pure_strings | identity | 120±2ms | 120±0.3ms | 0% |
    | pure_strings | sort | 522±0.9ms | 516±0.7ms | -1.2% |
    | pure_strings | nullcheck | 97.5±0.4ms | 100.0±0.6ms | +2.6% |
    | pure_ts | identity | 110±0.5ms | 110±1ms | 0% |
    | pure_ts | sort | 216±0.7ms | 215±0.9ms | -0.5% |
    | pure_ts | nullcheck | 89.0±0.2ms | 89.4±0.2ms | +0.4% |
    | mixed_types | identity | 105±0.4ms | 105±0.2ms | 0% |
    | mixed_types | sort | 166±0.6ms | 166±0.8ms | 0% |
    | mixed_types | nullcheck | 84.2±0.6ms | 83.1±0.3ms | -1.3% |
    
    Peak memory: no change (within 1M). All deltas are within the ±3% noise band, so there is no performance regression.
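The Delta column appears consistent with the relative change of the mean timing, (after - before) / before. The small sketch below assumes that formula, recomputes a few rows from the table, and checks them against the ±3% noise band quoted above. The ± spreads are ignored; `delta_pct` and `NOISE_PCT` are illustrative names.

```python
# A few (before, after) mean timings in ms, copied from the table above.
ROWS = {
    ("sm_batch_many_col", "identity"): (24.2, 23.6),
    ("lg_batch_few_col", "nullcheck"): (271.0, 278.0),
    ("pure_strings", "nullcheck"): (97.5, 100.0),
    ("pure_ints", "sort"): (179.0, 174.0),
}
NOISE_PCT = 3.0  # noise band quoted in the summary


def delta_pct(before: float, after: float) -> float:
    """Relative change of the mean in percent; negative means faster."""
    return (after - before) / before * 100.0


for (scenario, udf), (before, after) in ROWS.items():
    d = delta_pct(before, after)
    print(f"{scenario:<18} {udf:<9} {d:+.1f}%  within noise: {abs(d) <= NOISE_PCT}")
```

This prints -2.5%, +2.6%, +2.6% and -2.8% for the four rows, which matches the table.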
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #54296 from Yicong-Huang/SPARK-55390/refactor/scalar-arrow-udf.
    
    Lead-authored-by: Yicong Huang <[email protected]>
    Co-authored-by: Yicong-Huang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/conversion.py |  66 ++++++++++++++++++++-
 python/pyspark/worker.py         | 121 ++++++++++++++++++++++++---------------
 2 files changed, 139 insertions(+), 48 deletions(-)

diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 249bd2ffcb20..efd03089bbf8 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -21,7 +21,7 @@ import decimal
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, overload
 
 import pyspark
-from pyspark.errors import PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
 from pyspark.sql.pandas.types import (
     _dedup_names,
     _deduplicate_field_names,
@@ -107,6 +107,70 @@ class ArrowBatchTransformer:
             struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema)))
         return pa.RecordBatch.from_arrays([struct], ["_0"])
 
+    @classmethod
+    def enforce_schema(
+        cls,
+        batch: "pa.RecordBatch",
+        arrow_schema: "pa.Schema",
+        safecheck: bool = True,
+    ) -> "pa.RecordBatch":
+        """
+        Enforce target schema on a RecordBatch by reordering columns and coercing types.
+
+        .. note::
+            Currently this function is only used by UDTF. The error messages
+            are UDTF-specific (see SPARK-55723).
+
+        Parameters
+        ----------
+        batch : pa.RecordBatch
+            Input RecordBatch to transform.
+        arrow_schema : pa.Schema
+            Target Arrow schema. Callers should pre-compute this once via
+            to_arrow_schema() to avoid repeated conversion.
+        safecheck : bool, default True
+            If True, use safe casting (fails on overflow/truncation).
+
+        Returns
+        -------
+        pa.RecordBatch
+            RecordBatch with columns reordered and types coerced to match target schema.
+        """
+        import pyarrow as pa
+
+        if batch.num_columns == 0 or len(arrow_schema) == 0:
+            return batch
+
+        # Fast path: schema already matches (ignoring metadata), no work needed
+        if batch.schema.equals(arrow_schema, check_metadata=False):
+            return batch
+
+        # Check if columns are in the same order (by name) as the target schema.
+        # If so, use index-based access (faster than name lookup).
+        batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
+        target_names = [field.name for field in arrow_schema]
+        use_index = batch_names == target_names
+
+        coerced_arrays = []
+        for i, field in enumerate(arrow_schema):
+            arr = batch.column(i) if use_index else batch.column(field.name)
+            if arr.type != field.type:
+                try:
+                    arr = arr.cast(target_type=field.type, safe=safecheck)
+                except (pa.ArrowInvalid, pa.ArrowTypeError):
+                    # TODO(SPARK-55723): Unify error messages for all UDF types,
+                    #  not just UDTF.
+                    raise PySparkRuntimeError(
+                        errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
+                        messageParameters={
+                            "expected": str(field.type),
+                            "actual": str(arr.type),
+                        },
+                    )
+            coerced_arrays.append(arr)
+
+        return pa.RecordBatch.from_arrays(coerced_arrays, names=target_names)
+
     @classmethod
     def to_pandas(
         cls,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7fbe0849ee63..ddd0b45d9020 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -72,7 +72,7 @@ from pyspark.sql.pandas.serializers import (
     ArrowStreamUDTFSerializer,
     ArrowStreamArrowUDTFSerializer,
 )
-from pyspark.sql.pandas.types import to_arrow_type
+from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -80,6 +80,7 @@ from pyspark.sql.types import (
     MapType,
     Row,
     StringType,
+    StructField,
     StructType,
     _create_row,
     _parse_datatype_json_string,
@@ -266,6 +267,39 @@ def verify_result(expected_type: type) -> Callable[[Any], Iterator]:
     return check
 
 
+def verify_scalar_result(result: Any, num_rows: int) -> Any:
+    """
+    Verify a scalar UDF result is array-like and has the expected number of rows.
+
+    Parameters
+    ----------
+    result : Any
+        The UDF result to verify.
+    num_rows : int
+        Expected number of rows (must match input batch size).
+    """
+    try:
+        result_length = len(result)
+    except TypeError:
+        raise PySparkTypeError(
+            errorClass="UDF_RETURN_TYPE",
+            messageParameters={
+                "expected": "array-like object",
+                "actual": type(result).__name__,
+            },
+        )
+    if result_length != num_rows:
+        raise PySparkRuntimeError(
+            errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+            messageParameters={
+                "udf_type": "arrow_udf",
+                "expected": str(num_rows),
+                "actual": str(result_length),
+            },
+        )
+    return result
+
+
 def wrap_udf(f, args_offsets, kwargs_offsets, return_type):
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
 
@@ -312,46 +346,6 @@ def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type, runner_
     )
 
 
-def wrap_scalar_arrow_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
-    func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
-
-    arrow_return_type = to_arrow_type(
-        return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types
-    )
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pyarrow.Array"
-            raise PySparkTypeError(
-                errorClass="UDF_RETURN_TYPE",
-                messageParameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
-
-    def verify_result_length(result, length):
-        if len(result) != length:
-            raise PySparkRuntimeError(
-                errorClass="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
-                messageParameters={
-                    "udf_type": "arrow_udf",
-                    "expected": str(length),
-                    "actual": str(len(result)),
-                },
-            )
-        return result
-
-    return (
-        args_kwargs_offsets,
-        lambda *a: (
-            verify_result_length(verify_result_type(func(*a)), len(a[0])),
-            arrow_return_type,
-        ),
-    )
-
-
 def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
     if runner_conf.use_legacy_pandas_udf_conversion:
         return wrap_arrow_batch_udf_legacy(
@@ -1403,7 +1397,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
-        return wrap_scalar_arrow_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
+        return func, args_offsets, kwargs_offsets, return_type
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
         return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
@@ -1413,7 +1407,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
         return args_offsets, wrap_pandas_batch_iter_udf(func, return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
-        return func
+        return func, None, None, None
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = inspect.getfullargspec(chained_func)  # signature was lost when wrapping it
         return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
@@ -2764,12 +2758,12 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
             ser = TransformWithStateInPySparkRowInitStateSerializer(
                 arrow_max_records_per_batch=runner_conf.arrow_max_records_per_batch
             )
-        elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
-            ser = ArrowStreamSerializer(write_start_stream=True)
         elif eval_type in (
+            PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
             PythonEvalType.SQL_SCALAR_ARROW_UDF,
-            PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         ):
+            ser = ArrowStreamSerializer(write_start_stream=True)
+        elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF:
             # Arrow cast and safe check are always enabled
             ser = ArrowStreamArrowUDFSerializer(safecheck=True, arrow_cast=True)
         elif (
@@ -2831,7 +2825,7 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         import pyarrow as pa
 
         assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
-        udf_func: Callable[[Iterator[pa.RecordBatch]], Iterator[pa.RecordBatch]] = udfs[0]
+        udf_func: Callable[[Iterator[pa.RecordBatch]], Iterator[pa.RecordBatch]] = udfs[0][0]
 
         def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
             """Apply mapInArrow UDF"""
@@ -2851,6 +2845,39 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         # profiling is not supported for UDF
         return func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
+        import pyarrow as pa
+
+        col_names = ["_%d" % i for i in range(len(udfs))]
+        combined_arrow_schema = to_arrow_schema(
+            StructType([StructField(n, rt) for n, (_, _, _, rt) in zip(col_names, udfs)]),
+            timezone="UTC",
+            prefers_large_types=runner_conf.use_large_var_types,
+        )
+
+        def func(split_index: int, batches: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            """Apply scalar Arrow UDFs"""
+
+            for input_batch in batches:
+                output_batch = pa.RecordBatch.from_arrays(
+                    [
+                        udf_func(
+                            *[input_batch.column(o) for o in args_offsets],
+                            **{k: input_batch.column(v) for k, v in kwargs_offsets.items()},
+                        )
+                        for udf_func, args_offsets, kwargs_offsets, _ in udfs
+                    ],
+                    col_names,
+                )
+                output_batch = ArrowBatchTransformer.enforce_schema(
+                    output_batch, combined_arrow_schema
+                )
+                verify_scalar_result(output_batch, input_batch.num_rows)
+                yield output_batch
+
+        # profiling is not supported for UDF
+        return func, None, ser, ser
+
     is_scalar_iter = eval_type in (
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,

