This is an automated email from the ASF dual-hosted git repository.

haejoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 459483acb45e [SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for PySpark
459483acb45e is described below

commit 459483acb45e44592c81e0f449c09c4607a680a4
Author: Haejoon Lee <[email protected]>
AuthorDate: Mon Jan 6 16:33:19 2025 +0900

    [SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to support the `SparkSession.interruptTag` and `SparkSession.interruptAll` APIs in PySpark with Spark Classic.
    
    ### Why are the changes needed?
    
    To improve the compatibility between Spark Connect and Spark Classic.
    
    ### Does this PR introduce _any_ user-facing change?
    
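    Yes. The following APIs, previously available only with Spark Connect, are now supported in Spark Classic:
    - SparkSession.interruptTag
    - SparkSession.interruptAll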
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #49014 from itholic/SPARK-50357.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Haejoon Lee <[email protected]>
---
 .../source/reference/pyspark.sql/spark_session.rst |  4 +--
 python/pyspark/sql/session.py                      | 34 +++++++++++++++-------
 .../tests/connect/test_parity_job_cancellation.py  | 22 --------------
 .../sql/tests/test_connect_compatibility.py        |  2 --
 python/pyspark/sql/tests/test_job_cancellation.py  | 22 ++++++++++++++
 python/pyspark/sql/tests/test_session.py           |  1 -
 6 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst 
b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 1677d3e8e020..a35fccbcffe9 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -52,6 +52,8 @@ See also :class:`SparkSession`.
     SparkSession.dataSource
     SparkSession.getActiveSession
     SparkSession.getTags
+    SparkSession.interruptAll
+    SparkSession.interruptTag
     SparkSession.newSession
     SparkSession.profile
     SparkSession.removeTag
@@ -86,8 +88,6 @@ Spark Connect Only
     SparkSession.clearProgressHandlers
     SparkSession.client
     SparkSession.copyFromLocalToFs
-    SparkSession.interruptAll
     SparkSession.interruptOperation
-    SparkSession.interruptTag
     SparkSession.registerProgressHandler
     SparkSession.removeProgressHandler
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f3a1639fddaf..fc434cd16bfb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -2197,13 +2197,15 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
         )
 
-    @remote_only
     def interruptAll(self) -> List[str]:
         """
         Interrupt all operations of this session currently running on the 
connected server.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2213,18 +2215,25 @@ class SparkSession(SparkConversionMixin):
         -----
         There is still a possibility of operation finishing just as it is 
interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptAll"},
-        )
+        java_list = self._jsparkSession.interruptAll()
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
-    @remote_only
     def interruptTag(self, tag: str) -> List[str]:
         """
         Interrupt all operations of this session with the given operation tag.
 
         .. versionadded:: 3.5.0
 
+        .. versionchanged:: 4.0.0
+            Supports Spark Classic.
+
         Returns
         -------
         list of str
@@ -2234,10 +2243,15 @@ class SparkSession(SparkConversionMixin):
         -----
         There is still a possibility of operation finishing just as it is 
interrupted.
         """
-        raise PySparkRuntimeError(
-            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
-            messageParameters={"feature": "SparkSession.interruptTag"},
-        )
+        java_list = self._jsparkSession.interruptTag(tag)
+        python_list = list()
+
+        # Use iterator to manually iterate through Java list
+        java_iterator = java_list.iterator()
+        while java_iterator.hasNext():
+            python_list.append(str(java_iterator.next()))
+
+        return python_list
 
     @remote_only
     def interruptOperation(self, op_id: str) -> List[str]:
diff --git a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py 
b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
index c5184b04d6aa..ddb4554afa55 100644
--- a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
+++ b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
@@ -32,28 +32,6 @@ class JobCancellationParityTests(JobCancellationTestsMixin, 
ReusedConnectTestCas
             create_thread=lambda target, session: 
threading.Thread(target=func, args=(target,))
         )
 
-    def test_interrupt_tag(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: self.spark.addTag(job_group),
-            lambda job_group: self.spark.interruptTag(job_group),
-            thread_ids,
-            [i for i in thread_ids if i % 2 == 0],
-            [i for i in thread_ids if i % 2 != 0],
-        )
-        self.spark.clearTags()
-
-    def test_interrupt_all(self):
-        thread_ids = range(4)
-        self.check_job_cancellation(
-            lambda job_group: None,
-            lambda job_group: self.spark.interruptAll(),
-            thread_ids,
-            thread_ids,
-            [],
-        )
-        self.spark.clearTags()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py 
b/python/pyspark/sql/tests/test_connect_compatibility.py
index ef83dc3834d0..25b8be1f9ac7 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -266,9 +266,7 @@ class ConnectCompatibilityTestsMixin:
             "addArtifacts",
             "clearProgressHandlers",
             "copyFromLocalToFs",
-            "interruptAll",
             "interruptOperation",
-            "interruptTag",
             "newSession",
             "registerProgressHandler",
             "removeProgressHandler",
diff --git a/python/pyspark/sql/tests/test_job_cancellation.py 
b/python/pyspark/sql/tests/test_job_cancellation.py
index a046c9c01811..3f30f7880889 100644
--- a/python/pyspark/sql/tests/test_job_cancellation.py
+++ b/python/pyspark/sql/tests/test_job_cancellation.py
@@ -166,6 +166,28 @@ class JobCancellationTestsMixin:
         self.assertEqual(first, {"a", "b"})
         self.assertEqual(second, {"a", "b", "c"})
 
+    def test_interrupt_tag(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: self.spark.addTag(job_group),
+            lambda job_group: self.spark.interruptTag(job_group),
+            thread_ids,
+            [i for i in thread_ids if i % 2 == 0],
+            [i for i in thread_ids if i % 2 != 0],
+        )
+        self.spark.clearTags()
+
+    def test_interrupt_all(self):
+        thread_ids = range(4)
+        self.check_job_cancellation(
+            lambda job_group: None,
+            lambda job_group: self.spark.interruptAll(),
+            thread_ids,
+            thread_ids,
+            [],
+        )
+        self.spark.clearTags()
+
 
 class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_session.py 
b/python/pyspark/sql/tests/test_session.py
index 3fbc0be943e4..a22fe777e3c9 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -227,7 +227,6 @@ class SparkSessionTests3(unittest.TestCase, 
PySparkErrorTestUtils):
                 (lambda: session.client, "client"),
                 (session.addArtifacts, "addArtifact(s)"),
                 (lambda: session.copyFromLocalToFs("", ""), 
"copyFromLocalToFs"),
-                (lambda: session.interruptTag(""), "interruptTag"),
                 (lambda: session.interruptOperation(""), "interruptOperation"),
             ]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to