This is an automated email from the ASF dual-hosted git repository.
haejoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 459483acb45e [SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for
PySpark
459483acb45e is described below
commit 459483acb45e44592c81e0f449c09c4607a680a4
Author: Haejoon Lee <[email protected]>
AuthorDate: Mon Jan 6 16:33:19 2025 +0900
[SPARK-50357][PYTHON] Support Interrupt(Tag|All) APIs for PySpark
### What changes were proposed in this pull request?
This PR proposes to support `Interrupt(Tag|All)` for PySpark
### Why are the changes needed?
To improve the compatibility between Spark Connect and Spark Classic.
### Does this PR introduce _any_ user-facing change?
New APIs are added
- InterruptTag
- InterruptAll
### How was this patch tested?
Added UTs
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49014 from itholic/SPARK-50357.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Haejoon Lee <[email protected]>
---
.../source/reference/pyspark.sql/spark_session.rst | 4 +--
python/pyspark/sql/session.py | 34 +++++++++++++++-------
.../tests/connect/test_parity_job_cancellation.py | 22 --------------
.../sql/tests/test_connect_compatibility.py | 2 --
python/pyspark/sql/tests/test_job_cancellation.py | 22 ++++++++++++++
python/pyspark/sql/tests/test_session.py | 1 -
6 files changed, 48 insertions(+), 37 deletions(-)
diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst
b/python/docs/source/reference/pyspark.sql/spark_session.rst
index 1677d3e8e020..a35fccbcffe9 100644
--- a/python/docs/source/reference/pyspark.sql/spark_session.rst
+++ b/python/docs/source/reference/pyspark.sql/spark_session.rst
@@ -52,6 +52,8 @@ See also :class:`SparkSession`.
SparkSession.dataSource
SparkSession.getActiveSession
SparkSession.getTags
+ SparkSession.interruptAll
+ SparkSession.interruptTag
SparkSession.newSession
SparkSession.profile
SparkSession.removeTag
@@ -86,8 +88,6 @@ Spark Connect Only
SparkSession.clearProgressHandlers
SparkSession.client
SparkSession.copyFromLocalToFs
- SparkSession.interruptAll
SparkSession.interruptOperation
- SparkSession.interruptTag
SparkSession.registerProgressHandler
SparkSession.removeProgressHandler
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f3a1639fddaf..fc434cd16bfb 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -2197,13 +2197,15 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
)
- @remote_only
def interruptAll(self) -> List[str]:
"""
Interrupt all operations of this session currently running on the
connected server.
.. versionadded:: 3.5.0
+ .. versionchanged:: 4.0.0
+ Supports Spark Classic.
+
Returns
-------
list of str
@@ -2213,18 +2215,25 @@ class SparkSession(SparkConversionMixin):
-----
There is still a possibility of operation finishing just as it is
interrupted.
"""
- raise PySparkRuntimeError(
- errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
- messageParameters={"feature": "SparkSession.interruptAll"},
- )
+ java_list = self._jsparkSession.interruptAll()
+ python_list = list()
+
+ # Use iterator to manually iterate through Java list
+ java_iterator = java_list.iterator()
+ while java_iterator.hasNext():
+ python_list.append(str(java_iterator.next()))
+
+ return python_list
- @remote_only
def interruptTag(self, tag: str) -> List[str]:
"""
Interrupt all operations of this session with the given operation tag.
.. versionadded:: 3.5.0
+ .. versionchanged:: 4.0.0
+ Supports Spark Classic.
+
Returns
-------
list of str
@@ -2234,10 +2243,15 @@ class SparkSession(SparkConversionMixin):
-----
There is still a possibility of operation finishing just as it is
interrupted.
"""
- raise PySparkRuntimeError(
- errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
- messageParameters={"feature": "SparkSession.interruptTag"},
- )
+ java_list = self._jsparkSession.interruptTag(tag)
+ python_list = list()
+
+ # Use iterator to manually iterate through Java list
+ java_iterator = java_list.iterator()
+ while java_iterator.hasNext():
+ python_list.append(str(java_iterator.next()))
+
+ return python_list
@remote_only
def interruptOperation(self, op_id: str) -> List[str]:
diff --git a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
index c5184b04d6aa..ddb4554afa55 100644
--- a/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
+++ b/python/pyspark/sql/tests/connect/test_parity_job_cancellation.py
@@ -32,28 +32,6 @@ class JobCancellationParityTests(JobCancellationTestsMixin,
ReusedConnectTestCase):
create_thread=lambda target, session:
threading.Thread(target=func, args=(target,))
)
- def test_interrupt_tag(self):
- thread_ids = range(4)
- self.check_job_cancellation(
- lambda job_group: self.spark.addTag(job_group),
- lambda job_group: self.spark.interruptTag(job_group),
- thread_ids,
- [i for i in thread_ids if i % 2 == 0],
- [i for i in thread_ids if i % 2 != 0],
- )
- self.spark.clearTags()
-
- def test_interrupt_all(self):
- thread_ids = range(4)
- self.check_job_cancellation(
- lambda job_group: None,
- lambda job_group: self.spark.interruptAll(),
- thread_ids,
- thread_ids,
- [],
- )
- self.spark.clearTags()
-
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py
b/python/pyspark/sql/tests/test_connect_compatibility.py
index ef83dc3834d0..25b8be1f9ac7 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -266,9 +266,7 @@ class ConnectCompatibilityTestsMixin:
"addArtifacts",
"clearProgressHandlers",
"copyFromLocalToFs",
- "interruptAll",
"interruptOperation",
- "interruptTag",
"newSession",
"registerProgressHandler",
"removeProgressHandler",
diff --git a/python/pyspark/sql/tests/test_job_cancellation.py
b/python/pyspark/sql/tests/test_job_cancellation.py
index a046c9c01811..3f30f7880889 100644
--- a/python/pyspark/sql/tests/test_job_cancellation.py
+++ b/python/pyspark/sql/tests/test_job_cancellation.py
@@ -166,6 +166,28 @@ class JobCancellationTestsMixin:
self.assertEqual(first, {"a", "b"})
self.assertEqual(second, {"a", "b", "c"})
+ def test_interrupt_tag(self):
+ thread_ids = range(4)
+ self.check_job_cancellation(
+ lambda job_group: self.spark.addTag(job_group),
+ lambda job_group: self.spark.interruptTag(job_group),
+ thread_ids,
+ [i for i in thread_ids if i % 2 == 0],
+ [i for i in thread_ids if i % 2 != 0],
+ )
+ self.spark.clearTags()
+
+ def test_interrupt_all(self):
+ thread_ids = range(4)
+ self.check_job_cancellation(
+ lambda job_group: None,
+ lambda job_group: self.spark.interruptAll(),
+ thread_ids,
+ thread_ids,
+ [],
+ )
+ self.spark.clearTags()
+
class JobCancellationTests(JobCancellationTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/test_session.py
b/python/pyspark/sql/tests/test_session.py
index 3fbc0be943e4..a22fe777e3c9 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -227,7 +227,6 @@ class SparkSessionTests3(unittest.TestCase,
PySparkErrorTestUtils):
(lambda: session.client, "client"),
(session.addArtifacts, "addArtifact(s)"),
(lambda: session.copyFromLocalToFs("", ""),
"copyFromLocalToFs"),
- (lambda: session.interruptTag(""), "interruptTag"),
(lambda: session.interruptOperation(""), "interruptOperation"),
]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]