This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 191801b7f34e [SPARK-55665][PYTHON] Unify how workers establish 
connection with the executor
191801b7f34e is described below

commit 191801b7f34e5c0db2d9a2c62f28117ccb69ac2e
Author: Tian Gao <[email protected]>
AuthorDate: Thu Feb 26 09:52:46 2026 +0800

    [SPARK-55665][PYTHON] Unify how workers establish connection with the 
executor
    
    ### What changes were proposed in this pull request?
    
    Unify the socket-file connection logic from the different worker files into a single helper, and guarantee an explicit flush and close of the socket file.
    
    ### Why are the changes needed?
    
    This piece of code is currently copy/pasted all over our code base, which introduces a few issues:
    * Code duplication, obviously.
    * During the copy/paste, we actually made some mistakes: `data_source_pushdown_filters.py` forgot to write the PID back, and this was never tested.
    * We can't guarantee a flush and close of the socket file. We currently rely on GC to do that, which is not reliable; this has caused issues for simple workers.
    * If we want to drop the PID communication in the future (TODO), or add an explicit flush now, we would need to make changes all over the code base.
    
    It's best to just organize the code at a single place.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Locally `test_python_datasource` passed, the rest is on CI.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, Cursor(claude-4.6-opus-high).
    
    Closes #54458 from gaogaotiantian/sockfile-to-executor.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../streaming/worker/foreach_batch_worker.py       | 15 +++------------
 .../connect/streaming/worker/listener_worker.py    | 15 +++------------
 .../streaming/python_streaming_source_runner.py    | 16 ++++------------
 .../transform_with_state_driver_worker.py          | 14 +++-----------
 python/pyspark/sql/worker/analyze_udtf.py          | 15 +++------------
 .../pyspark/sql/worker/commit_data_source_write.py | 15 +++------------
 python/pyspark/sql/worker/create_data_source.py    | 14 +++-----------
 .../sql/worker/data_source_pushdown_filters.py     | 12 +++---------
 python/pyspark/sql/worker/lookup_data_sources.py   | 15 +++------------
 python/pyspark/sql/worker/plan_data_source_read.py | 14 +++-----------
 .../sql/worker/python_streaming_sink_runner.py     | 14 +++-----------
 .../pyspark/sql/worker/write_into_data_source.py   | 17 +++--------------
 python/pyspark/worker.py                           | 15 ++++-----------
 python/pyspark/worker_util.py                      | 22 +++++++++++++++++++++-
 14 files changed, 62 insertions(+), 151 deletions(-)

diff --git 
a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py 
b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index b819634adb5a..18bb459a2918 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -21,7 +21,7 @@ Usually this is ran on the driver side of the Spark Connect 
Server.
 """
 import os
 
-from pyspark.util import local_connect_and_auth
+from pyspark.worker_util import get_sock_file_to_executor
 from pyspark.serializers import (
     write_int,
     read_long,
@@ -90,14 +90,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
-    # There could be a long time between each micro batch.
-    sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor(timeout=None) as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py 
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index 2c6ce8715994..994339b5d90d 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -22,7 +22,7 @@ Usually this is ran on the driver side of the Spark Connect 
Server.
 import os
 import json
 
-from pyspark.util import local_connect_and_auth
+from pyspark.worker_util import get_sock_file_to_executor
 from pyspark.serializers import (
     read_int,
     write_int,
@@ -104,14 +104,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
-    # There could be a long time between each listener event.
-    sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor(timeout=None) as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py 
b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 31f70a59dbfb..44811c84548f 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -48,8 +48,9 @@ from pyspark.sql.types import (
     StructType,
 )
 from pyspark.sql.worker.plan_data_source_read import records_to_arrow_batches
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     check_python_version,
     read_command,
     pickleSer,
@@ -310,14 +311,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
-    # Prevent the socket from timeout error when query trigger interval is 
large.
-    sock.settimeout(None)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor(timeout=None) as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py 
b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
index 3fe7f68a99e5..a05d616eda2c 100644
--- a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
+++ b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 #
 
-import os
 import json
 from typing import Any, Iterator, TYPE_CHECKING
 
-from pyspark.util import local_connect_and_auth
+from pyspark.worker_util import get_sock_file_to_executor
 from pyspark.serializers import (
     write_int,
     read_int,
@@ -95,12 +94,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/analyze_udtf.py 
b/python/pyspark/sql/worker/analyze_udtf.py
index 7265138202cd..9328ab2f199b 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -16,7 +16,6 @@
 #
 
 import inspect
-import os
 from textwrap import dedent
 from typing import Dict, List, IO, Tuple
 
@@ -32,8 +31,8 @@ from pyspark.sql.functions import OrderingColumn, 
PartitioningColumn, SelectedCo
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     read_command,
     pickleSer,
     utf8_deserializer,
@@ -238,13 +237,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py 
b/python/pyspark/sql/worker/commit_data_source_write.py
index 6838a32db398..f16234d29b3d 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import os
 from typing import IO
 
 from pyspark.errors import PySparkAssertionError
@@ -26,8 +25,7 @@ from pyspark.serializers import (
 )
 from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
-from pyspark.worker_util import pickleSer
+from pyspark.worker_util import get_sock_file_to_executor, pickleSer
 
 
 def _main(infile: IO, outfile: IO) -> None:
@@ -78,12 +76,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/create_data_source.py 
b/python/pyspark/sql/worker/create_data_source.py
index 625b08088e60..41899774a7d3 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 import inspect
-import os
 from typing import IO
 
 from pyspark.errors import PySparkAssertionError, PySparkTypeError
@@ -29,8 +28,8 @@ from pyspark.serializers import (
 from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     read_command,
     pickleSer,
     utf8_deserializer,
@@ -146,12 +145,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py 
b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index 06b71e1ca8f1..a64993642299 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -17,7 +17,6 @@
 
 import base64
 import json
-import os
 import typing
 from dataclasses import dataclass, field
 from typing import IO, Type, Union
@@ -47,8 +46,8 @@ from pyspark.sql.datasource import (
 from pyspark.sql.types import StructType, VariantVal, 
_parse_datatype_json_string
 from pyspark.sql.worker.plan_data_source_read import 
write_read_func_and_partitions
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     pickleSer,
     read_command,
 )
@@ -226,10 +225,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py 
b/python/pyspark/sql/worker/lookup_data_sources.py
index e432f40d6904..aa5e7c4abb5a 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -16,7 +16,6 @@
 #
 from importlib import import_module
 from pkgutil import iter_modules
-import os
 from typing import IO
 
 from pyspark.serializers import (
@@ -25,8 +24,7 @@ from pyspark.serializers import (
 )
 from pyspark.sql.datasource import DataSource
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
-from pyspark.worker_util import pickleSer
+from pyspark.worker_util import get_sock_file_to_executor, pickleSer
 
 
 def _main(infile: IO, outfile: IO) -> None:
@@ -60,12 +58,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index ed1c602b0af4..c858a99462b1 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import os
 import functools
 import pyarrow as pa
 from itertools import islice, chain
@@ -44,8 +43,8 @@ from pyspark.sql.types import (
     StructType,
 )
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     read_command,
     pickleSer,
     utf8_deserializer,
@@ -376,12 +375,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py 
b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index 952722d0d946..2a4ea0b95b28 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import os
 from typing import IO
 
 from pyspark.errors import PySparkAssertionError
@@ -32,8 +31,8 @@ from pyspark.sql.types import (
     StructType,
 )
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import local_connect_and_auth
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     read_command,
     pickleSer,
     utf8_deserializer,
@@ -113,12 +112,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/write_into_data_source.py 
b/python/pyspark/sql/worker/write_into_data_source.py
index 111829bb7d58..83bdedb2fdbe 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 import inspect
-import os
 from typing import IO, Iterable, Iterator
 
 from pyspark.sql.conversion import ArrowTableToRowsConversion
@@ -24,7 +23,6 @@ from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
-    write_int,
 )
 from pyspark.sql import Row
 from pyspark.sql.datasource import (
@@ -43,10 +41,8 @@ from pyspark.sql.types import (
     _create_row,
 )
 from pyspark.sql.worker.utils import worker_run
-from pyspark.util import (
-    local_connect_and_auth,
-)
 from pyspark.worker_util import (
+    get_sock_file_to_executor,
     read_command,
     pickleSer,
     utf8_deserializer,
@@ -241,12 +237,5 @@ def main(infile: IO, outfile: IO) -> None:
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3717ba6d4d6a..a6ee3f25e486 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -36,7 +36,7 @@ from pyspark.accumulators import (
 from pyspark.sql.streaming.stateful_processor_api_client import 
StatefulProcessorApiClient
 from pyspark.sql.streaming.stateful_processor_util import 
TransformWithStateInPandasFuncMode
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
-from pyspark.util import PythonEvalType, local_connect_and_auth
+from pyspark.util import PythonEvalType
 from pyspark.serializers import (
     write_int,
     read_long,
@@ -95,6 +95,7 @@ from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError, 
PySparkValueError
 from pyspark.worker_util import (
     check_python_version,
+    get_sock_file_to_executor,
     read_command,
     pickleSer,
     send_accumulator_updates,
@@ -3449,13 +3450,5 @@ def main(infile, outfile):
 
 
 if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the 
environment.
-    conn_info = os.environ.get(
-        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
-    )
-    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
-    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
-    # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
+    with get_sock_file_to_executor() as sock_file:
+        main(sock_file, sock_file)
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index ac090bc955bd..ccf937b5e734 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -18,11 +18,12 @@
 """
 Util functions for workers.
 """
+from contextlib import contextmanager
 import importlib
 from inspect import currentframe, getframeinfo
 import os
 import sys
-from typing import Any, IO, Optional
+from typing import Any, Generator, IO, Optional
 import warnings
 
 if "SPARK_TESTING" in os.environ:
@@ -192,6 +193,25 @@ def setup_broadcasts(infile: IO) -> None:
         broadcast_sock_file.close()
 
 
+@contextmanager
+def get_sock_file_to_executor(timeout: Optional[int] = -1) -> Generator[IO, 
None, None]:
+    # Read information about how to connect back to the JVM from the 
environment.
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    sock_file, sock = local_connect_and_auth(conn_info, auth_secret)
+    if timeout is None or timeout > 0:
+        sock.settimeout(timeout)
+    # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
+    write_int(os.getpid(), sock_file)
+    sock_file.flush()
+    try:
+        yield sock_file
+    finally:
+        sock_file.close()
+
+
 def send_accumulator_updates(outfile: IO) -> None:
     """
     Send the accumulator updates back to JVM.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to