This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e4936d0fe8f [SPARK-54144][PYTHON] Short Circuit Eval Type Inferences
6e4936d0fe8f is described below

commit 6e4936d0fe8fef932c17a20260a227cdb32142eb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Nov 2 20:32:19 2025 -0800

    [SPARK-54144][PYTHON] Short Circuit Eval Type Inferences
    
    ### What changes were proposed in this pull request?
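    Short-circuit the eval type inferences: return the eval type as soon as one signature check matches, instead of evaluating the paired check first.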
    
    ### Why are the changes needed?
    Minor optimization that avoids unnecessary inference.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52843 from zhengruifeng/short_circuit_type_infer.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/pandas/typehints.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py
index 18858ab0cf68..7c95feee0cfe 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -353,6 +353,9 @@ def infer_group_arrow_eval_type(
             return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
         )
     )
+    if is_iterator_batch:
+        return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+
     # Tuple[pa.Scalar, ...], Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
     is_iterator_batch_with_keys = (
         len(parameters_sig) == 2
@@ -364,19 +367,21 @@ def infer_group_arrow_eval_type(
             return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
         )
     )
-
-    if is_iterator_batch or is_iterator_batch_with_keys:
+    if is_iterator_batch_with_keys:
         return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
 
     # pa.Table -> pa.Table
     is_table = (
         len(parameters_sig) == 1 and parameters_sig[0] == pa.Table and return_annotation == pa.Table
     )
+    if is_table:
+        return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
     # Tuple[pa.Scalar, ...], pa.Table -> pa.Table
     is_table_with_keys = (
         len(parameters_sig) == 2 and parameters_sig[1] == pa.Table and return_annotation == pa.Table
     )
-    if is_table or is_table_with_keys:
+    if is_table_with_keys:
         return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
 
     return None
@@ -441,6 +446,9 @@ def infer_group_pandas_eval_type(
             return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
         )
     )
+    if is_iterator_dataframe:
+        return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+
     # Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
     is_iterator_dataframe_with_keys = (
         len(parameters_sig) == 2
@@ -452,8 +460,7 @@ def infer_group_pandas_eval_type(
             return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
         )
     )
-
-    if is_iterator_dataframe or is_iterator_dataframe_with_keys:
+    if is_iterator_dataframe_with_keys:
         return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
 
     # pd.DataFrame -> pd.DataFrame
@@ -462,13 +469,16 @@ def infer_group_pandas_eval_type(
         and parameters_sig[0] == pd.DataFrame
         and return_annotation == pd.DataFrame
     )
+    if is_dataframe:
+        return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
     # Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
     is_dataframe_with_keys = (
         len(parameters_sig) == 2
         and parameters_sig[1] == pd.DataFrame
         and return_annotation == pd.DataFrame
     )
-    if is_dataframe or is_dataframe_with_keys:
+    if is_dataframe_with_keys:
         return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
 
     return None
@@ -512,11 +522,9 @@ def check_iterator_annotation(
 def check_union_annotation(
     annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
 ) -> bool:
-    import typing
-
     # Note that we cannot rely on '__origin__' in other type hints as it has changed from version
     # to version.
     origin = getattr(annotation, "__origin__", None)
-    return origin == typing.Union and (
+    return origin == Union and (
         parameter_check_func is None or all(map(parameter_check_func, annotation.__args__))
     )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to