This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cb1d3bf878 [SPARK-44211][PYTHON][CONNECT] Implement SparkSession.is_stopped
2cb1d3bf878 is described below

commit 2cb1d3bf87808896e7ef6467f7afb21d1e0a50fb
Author: Alice Sayutina <[email protected]>
AuthorDate: Mon Jul 3 09:18:53 2023 +0900

    [SPARK-44211][PYTHON][CONNECT] Implement SparkSession.is_stopped
    
    ### What changes were proposed in this pull request?
    Creates SparkSession.is_stopped, which returns whether this session was stopped previously.
    
    ### Why are the changes needed?
    It is currently not possible to determine whether the session has been closed.
    
    ### Does this PR introduce _any_ user-facing change?
    Introduces the is_stopped property.
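    
    As a small usage sketch (the "sc://localhost" endpoint is a placeholder for
    any reachable Spark Connect server and is not part of this change):
    
    ```python
    from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
    
    # Placeholder endpoint for illustration only.
    session = RemoteSparkSession.builder.remote("sc://localhost").getOrCreate()
    
    print(session.is_stopped)  # False while the session is live
    session.stop()
    print(session.is_stopped)  # True once the underlying client channel is closed
    ```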
    
    ### How was this patch tested?
    Unit Tests
    
    Closes #41760 from cdkrot/master.
    
    Authored-by: Alice Sayutina <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/core.py              | 9 +++++++++
 python/pyspark/sql/connect/session.py                  | 7 +++++++
 python/pyspark/sql/tests/connect/client/test_client.py | 7 +++++++
 python/pyspark/sql/tests/connect/test_session.py       | 7 +++++++
 4 files changed, 30 insertions(+)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 7368521259a..f8d304e9ccc 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -594,6 +594,7 @@ class SparkConnectClient(object):
             self._user_id = os.getenv("USER", None)
 
         self._channel = self._builder.toChannel()
+        self._closed = False
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
         self._artifact_manager = ArtifactManager(self._user_id, self._session_id, self._channel)
         # Configure logging for the SparkConnect client.
@@ -835,6 +836,14 @@ class SparkConnectClient(object):
         Close the channel.
         """
         self._channel.close()
+        self._closed = True
+
+    @property
+    def is_closed(self) -> bool:
+        """
+        Returns if the channel was closed previously using close() method
+        """
+        return self._closed
 
     @property
     def host(self) -> str:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 356dacd8e18..358674c9189 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -573,6 +573,13 @@ class SparkSession:
 
     stop.__doc__ = PySparkSession.stop.__doc__
 
+    @property
+    def is_stopped(self) -> bool:
+        """
+        Returns if this session was stopped
+        """
+        return self.client.is_closed
+
     @classmethod
     def getActiveSession(cls) -> Any:
         raise PySparkNotImplementedError(
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 3e3ce6f40df..5c39d4502f5 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -79,6 +79,13 @@ class SparkConnectClientTestCase(unittest.TestCase):
         client.interrupt_all()
         self.assertIsNotNone(mock.req, "Interrupt API was not called when expected")
 
+    def test_is_closed(self):
+        client = SparkConnectClient("sc://foo/;token=bar")
+
+        self.assertFalse(client.is_closed)
+        client.close()
+        self.assertTrue(client.is_closed)
+
 
 class MockService:
     # Simplest mock of the SparkConnectService.
diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py
index 3cf6d91d404..bde22d80303 100644
--- a/python/pyspark/sql/tests/connect/test_session.py
+++ b/python/pyspark/sql/tests/connect/test_session.py
@@ -56,3 +56,10 @@ class SparkSessionTestCase(unittest.TestCase):
         test_session.stop()
 
         self.assertEqual("other", host)
+
+    def test_session_stop(self):
+        session = RemoteSparkSession.builder.remote("sc://other").getOrCreate()
+
+        self.assertFalse(session.is_stopped)
+        session.stop()
+        self.assertTrue(session.is_stopped)

