This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 854628f84565 [SPARK-55326][PYTHON][CONNECT] Release remote session 
when SPARK_CONNECT_RELEASE_SESSION_ON_EXIT is set
854628f84565 is described below

commit 854628f84565ec0b7016ee4db1eb077765f58972
Author: Bobby Wang <[email protected]>
AuthorDate: Tue Mar 3 09:27:30 2026 +0900

    [SPARK-55326][PYTHON][CONNECT] Release remote session when 
SPARK_CONNECT_RELEASE_SESSION_ON_EXIT is set
    
    ### What changes were proposed in this pull request?
    
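    This PR adds an `_on_exit` handler to `SparkConnectClient` that is registered
    with Python's `atexit` module (replacing the previous `_cleanup_ml_cache`
    registration, which `_on_exit` still calls). When enabled via the
    `SPARK_CONNECT_RELEASE_SESSION_ON_EXIT` environment variable (`true` or `1`,
    case-insensitive), the client automatically releases its server-side session
    and closes itself when the Python process exits, unless it was already closed.
    Errors raised while releasing or closing are ignored.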
    
    ### Why are the changes needed?
    
    Currently, when a PySpark Connect client process exits without explicitly
    calling `spark.stop()`, the session may remain active on the server side,
    consuming resources unnecessarily. This change provides an opt-in mechanism to
    automatically release the session during process exit.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now set the environment variable
    `SPARK_CONNECT_RELEASE_SESSION_ON_EXIT=true` (or `1`) to enable automatic session
    release when the Python process exits. The behavior is disabled by default.
    
    ### How was this patch tested?
    
    Added unit tests for `SparkConnectClient._on_exit` in `test_client.py`; CI passes.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, co-authored with claude-4.5-opus-high
    
    Closes #54106 from wbo4958/release-on-exit.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/core.py          |  19 ++-
 .../sql/tests/connect/client/test_client.py        | 138 +++++++++++++++++++++
 2 files changed, 155 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index aa060df24e41..ab7979a28326 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -755,8 +755,11 @@ class SparkConnectClient(object):
 
         self._release_futures: weakref.WeakSet[concurrent.futures.Future] = 
weakref.WeakSet()
 
-        # cleanup ml cache if possible
-        atexit.register(self._cleanup_ml_cache)
+        self._release_session_on_exit = os.getenv(  # opt-in: "true" or "1", case-insensitive
+            "SPARK_CONNECT_RELEASE_SESSION_ON_EXIT", "false"
+        ).lower() in ("true", "1")
+        # atexit hook: ML-cache cleanup, plus session release/close when opted in and not closed
+        atexit.register(self._on_exit)
 
         self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = []
         self.global_user_context_extensions_lock = threading.Lock()
@@ -2281,6 +2284,18 @@ class SparkConnectClient(object):
         except Exception:
             return []
 
+    def _on_exit(self) -> None:  # invoked by atexit at interpreter shutdown
+        self._cleanup_ml_cache()  # runs unconditionally, even if release-on-exit is off
+        if self._release_session_on_exit and not self._closed:  # opt-in; skip if closed
+            try:
+                self.release_session()
+            except Exception:  # best-effort; still attempt close() below
+                pass
+            try:
+                self.close()
+            except Exception:  # best-effort; errors are swallowed at exit
+                pass
+
     def _cleanup_ml_cache(self) -> None:
         try:
             command = pb2.Command()
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py 
b/python/pyspark/sql/tests/connect/client/test_client.py
index 55faff5e9ed3..85fbafe22728 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -450,6 +450,144 @@ class SparkConnectClientTestCase(unittest.TestCase):
         for resp in client._stub.ExecutePlan(req, metadata=None):
             assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
 
+    def test_on_exit_calls_release_and_close_when_enabled(self):
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
+        client._release_session_on_exit = True
+        client._closed = False
+
+        call_tracker = {"release_session": 0, "close": 0}
+
+        def mock_release_session():
+            call_tracker["release_session"] += 1
+
+        def mock_close():
+            call_tracker["close"] += 1
+
+        client.release_session = mock_release_session
+        client.close = mock_close
+
+        client._on_exit()  # opted in and not closed: release_session and close run once each
+
+        self.assertEqual(call_tracker["release_session"], 1)
+        self.assertEqual(call_tracker["close"], 1)
+
+    def test_on_exit_does_not_call_when_release_disabled(self):
+        """Test _on_exit skips release_session and close when 
_release_session_on_exit is False."""
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
+        client._release_session_on_exit = False
+        client._closed = False
+
+        call_tracker = {"release_session": 0, "close": 0}
+
+        def mock_release_session():
+            call_tracker["release_session"] += 1
+
+        def mock_close():
+            call_tracker["close"] += 1
+
+        client.release_session = mock_release_session
+        client.close = mock_close
+
+        client._on_exit()
+
+        self.assertEqual(call_tracker["release_session"], 0)
+        self.assertEqual(call_tracker["close"], 0)
+
+    def test_on_exit_does_not_call_when_already_closed(self):
+        """Test _on_exit skips release_session and close when the client is already closed."""
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
+        client._release_session_on_exit = True
+        client._closed = True
+
+        call_tracker = {"release_session": 0, "close": 0}
+
+        def mock_release_session():
+            call_tracker["release_session"] += 1
+
+        def mock_close():
+            call_tracker["close"] += 1
+
+        client.release_session = mock_release_session
+        client.close = mock_close
+
+        client._on_exit()
+
+        self.assertEqual(call_tracker["release_session"], 0)
+        self.assertEqual(call_tracker["close"], 0)
+
+    def test_on_exit_catches_release_session_exception(self):
+        """Test _on_exit continues to call close even if release_session 
raises."""
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
+        client._release_session_on_exit = True
+        client._closed = False
+
+        call_tracker = {"release_session": 0, "close": 0}
+
+        def mock_release_session():
+            call_tracker["release_session"] += 1
+            raise Exception("release error")
+
+        def mock_close():
+            call_tracker["close"] += 1
+
+        client.release_session = mock_release_session
+        client.close = mock_close
+
+        # Should not raise; close() must still run after release_session() fails
+        client._on_exit()
+
+        self.assertEqual(call_tracker["release_session"], 1)
+        self.assertEqual(call_tracker["close"], 1)
+
+    def test_on_exit_catches_close_exception(self):
+        """Test _on_exit silently catches exception from close."""
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
+        client._release_session_on_exit = True
+        client._closed = False
+
+        call_tracker = {"release_session": 0, "close": 0}
+
+        def mock_release_session():
+            call_tracker["release_session"] += 1
+
+        def mock_close():
+            call_tracker["close"] += 1
+            raise Exception("close error")
+
+        client.release_session = mock_release_session
+        client.close = mock_close
+
+        # Should not raise; the close() error is swallowed
+        client._on_exit()
+
+        self.assertEqual(call_tracker["release_session"], 1)
+        self.assertEqual(call_tracker["close"], 1)
+
+    def test_on_exit_catches_both_exceptions(self):
+        """Test _on_exit handles both release_session and close raising 
exceptions."""
+        client = SparkConnectClient("sc://foo/", 
use_reattachable_execute=False)
+        client._release_session_on_exit = True
+        client._closed = False
+
+        call_tracker = {"release_session": 0, "close": 0}
+
+        def mock_release_session():
+            call_tracker["release_session"] += 1
+            raise Exception("release error")
+
+        def mock_close():
+            call_tracker["close"] += 1
+            raise Exception("close error")
+
+        client.release_session = mock_release_session
+        client.close = mock_close
+
+        # Should not raise; both errors are swallowed and close() is still attempted
+        client._on_exit()
+
+        self.assertEqual(call_tracker["release_session"], 1)
+        self.assertEqual(call_tracker["close"], 1)
+
     def test_get_operations_statuses_all(self):
         """Test get_operations_statuses returns all operation statuses when no 
IDs specified."""
         OperationStatus = proto.GetStatusResponse.OperationStatus


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to