This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c16e93dce06 [SPARK-50446][PYTHON] Concurrent level in Arrow-optimized Python UDF
0c16e93dce06 is described below

commit 0c16e93dce06639b49fe6641aedcd1e9fc4df96a
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Nov 29 13:39:53 2024 +0900

    [SPARK-50446][PYTHON] Concurrent level in Arrow-optimized Python UDF
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add a configuration for concurrent execution of Python
    UDFs with Arrow optimization. One use case is, for example, a Python UDF
    that makes RESTful API requests and is therefore slowed down by I/O. With
    this configuration, those I/O requests can happen in parallel.
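
    Conceptually, when the configuration is set, the worker evaluates the rows
    of each Arrow batch through a thread pool instead of a plain loop. Below is
    a minimal standalone sketch of that pattern; the function `fetch` is a
    hypothetical stand-in for an I/O-bound UDF body, not the actual worker code:

    ```python
    from concurrent.futures import ThreadPoolExecutor

    import requests

    def fetch(x):
        # Hypothetical I/O-bound work, e.g. a blocking HTTP request per row.
        return requests.get("https://httpbin.org/get").status_code

    rows = range(100)
    # With 10 threads, the blocking requests overlap instead of running one
    # after another, which is where the speedup comes from.
    with ThreadPoolExecutor(max_workers=10) as pool:
        results = list(pool.map(fetch, rows))
    ```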
    
    ### Why are the changes needed?
    
    To speed up the execution of I/O-bound UDFs. For example, the code below:
    
    ```python
    spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", 10)
    spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
    
    from pyspark.sql.functions import udf, col
    import time
    
    @udf
    def my_rest_func(x):
        import requests
        requests.get("https://httpbin.org/get";)
    
    start_time = time.time()
    _ = spark.range(100).coalesce(1).select(my_rest_func(col("id"))).collect()
    print(time.time() - start_time)
    ```
    
    can run roughly 10x faster, since the blocking HTTP requests overlap across 10 threads instead of running one at a time.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds a new configuration called
    `spark.sql.execution.pythonUDF.arrow.concurrency.level`.
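
    For example, the configuration can be toggled at runtime (a minimal sketch;
    `spark` is assumed to be an active SparkSession):

    ```python
    # Enable concurrent evaluation with 10 threads per Arrow batch.
    spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", 10)

    # Revert to the default sequential evaluation.
    spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
    ```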
    
    ### How was this patch tested?
    
    Manually tested as shown above, and unit tests were also added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49000 from HyukjinKwon/async-exec.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/tests/test_arrow_python_udf.py  | 14 ++++++++++++++
 python/pyspark/worker.py                           | 22 +++++++++++++++++-----
 .../org/apache/spark/sql/internal/SQLConf.scala    | 13 +++++++++++++
 .../sql/execution/python/ArrowPythonRunner.scala   |  5 ++++-
 4 files changed, 48 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 095414334848..a3fd8c01992a 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -238,6 +238,20 @@ class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
             super(PythonUDFArrowTests, cls).tearDownClass()
 
 
+class AsyncPythonUDFArrowTests(PythonUDFArrowTests):
+    @classmethod
+    def setUpClass(cls):
+        super(AsyncPythonUDFArrowTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
+        finally:
+            super(AsyncPythonUDFArrowTests, cls).tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow_python_udf import *  # noqa: F401
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 1ebc04520eca..a11465e7a323 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -154,7 +154,7 @@ def wrap_scalar_pandas_udf(f, args_offsets, kwargs_offsets, return_type):
     )
 
 
-def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type):
+def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf):
     import pandas as pd
 
     func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets)
@@ -172,9 +172,21 @@ def wrap_arrow_batch_udf(f, args_offsets, kwargs_offsets, return_type):
     elif type(return_type) == BinaryType:
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
-    @fail_on_stopiteration
-    def evaluate(*args: pd.Series) -> pd.Series:
-        return pd.Series([result_func(func(*row)) for row in zip(*args)])
+    if "spark.sql.execution.pythonUDF.arrow.concurrency.level" in runner_conf:
+        from concurrent.futures import ThreadPoolExecutor
+
+        c = int(runner_conf["spark.sql.execution.pythonUDF.arrow.concurrency.level"])
+
+        @fail_on_stopiteration
+        def evaluate(*args: pd.Series) -> pd.Series:
+            with ThreadPoolExecutor(max_workers=c) as pool:
+                return pd.Series(list(pool.map(lambda row: result_func(func(*row)), zip(*args))))
+
+    else:
+
+        @fail_on_stopiteration
+        def evaluate(*args: pd.Series) -> pd.Series:
+            return pd.Series([result_func(func(*row)) for row in zip(*args)])
 
     def verify_result_length(result, length):
         if len(result) != length:
@@ -855,7 +867,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type)
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
-        return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type)
+        return wrap_arrow_batch_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
         return args_offsets, wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e8031580c116..2a05508a1754 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3321,6 +3321,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val PYTHON_UDF_ARROW_CONCURRENCY_LEVEL =
+    buildConf("spark.sql.execution.pythonUDF.arrow.concurrency.level")
+      .doc("The level of concurrency to execute Arrow-optimized Python UDF. " +
+        "This can be useful if Python UDFs use I/O intensively.")
+      .version("4.0.0")
+      .intConf
+      .checkValue(_ > 1,
+        "The value of spark.sql.execution.pythonUDF.arrow.concurrency.level" +
+          " must be more than one.")
+      .createOptional
+
   val PYTHON_TABLE_UDF_ARROW_ENABLED =
     buildConf("spark.sql.execution.pythonUDTF.arrow.enabled")
       .doc("Enable Arrow optimization for Python UDTFs.")
@@ -5997,6 +6008,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED)
 
+  def pythonUDFArrowConcurrencyLevel: Option[Int] = getConf(PYTHON_UDF_ARROW_CONCURRENCY_LEVEL)
+
   def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)
 
   def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index a555d660ea1a..72e9c5210194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -116,6 +116,9 @@ object ArrowPythonRunner {
       conf.pandasGroupedMapAssignColumnsByName.toString)
     val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
       conf.arrowSafeTypeConversion.toString)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
+    val arrowAyncParallelism = conf.pythonUDFArrowConcurrencyLevel.map(v =>
+      Seq(SQLConf.PYTHON_UDF_ARROW_CONCURRENCY_LEVEL.key -> v.toString)
+    ).getOrElse(Seq.empty)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism: _*)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
