This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fcecfe17f2 [SPARK-44194][PYTHON][CORE] Add JobTag APIs to PySpark SparkContext
4fcecfe17f2 is described below

commit 4fcecfe17f2d54e14ac204bbdd97104828bbf2af
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Jul 4 12:47:25 2023 +0900

    [SPARK-44194][PYTHON][CORE] Add JobTag APIs to PySpark SparkContext
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add:
    
    - `SparkContext.setInterruptOnCancel(interruptOnCancel: Boolean): Unit`
    - `SparkContext.addJobTag(tag: String): Unit`
    - `SparkContext.removeJobTag(tag: String): Unit`
    - `SparkContext.getJobTags(): Set[String]`
    - `SparkContext.clearJobTags(): Unit`
    - `SparkContext.cancelJobsWithTag(tag: String): Unit`
    
    into PySpark.
    
    See also SPARK-43952.
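
    As an illustration only (not part of this patch), here is a minimal sketch of how the new APIs could be combined to cancel a long-running job from another thread; the local master, app name, and tag name below are assumptions made for the sketch:

    ```python
    import time

    from pyspark import SparkContext, InheritableThread

    # Assumption: a local SparkContext created just for this sketch.
    sc = SparkContext("local[2]", "job-tag-sketch")

    def long_job():
        # Tag every job started by this thread and make cancellation interrupt tasks.
        sc.addJobTag("slow-batch")
        sc.setInterruptOnCancel(interruptOnCancel=True)
        try:
            # A deliberately slow job: each task sleeps for 30 seconds.
            sc.parallelize(range(4)).map(lambda x: time.sleep(30) or x).collect()
        except Exception:
            print("job cancelled")
        finally:
            sc.removeJobTag("slow-batch")

    # InheritableThread is recommended in pinned thread mode so that
    # JVM thread-local state such as job tags is handled correctly.
    worker = InheritableThread(target=long_job)
    worker.start()
    time.sleep(5)                          # give the job time to start
    sc.cancelJobsWithTag("slow-batch")     # cancels only jobs carrying this tag
    worker.join()
    sc.clearJobTags()
    ```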
    
    ### Why are the changes needed?
    
    To provide feature parity for PySpark users.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds new APIs in PySpark.
    
    ### How was this patch tested?
    
    Unit tests were added.
    
    Closes #41841 from HyukjinKwon/SPARK-44194.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/reference/pyspark.rst |   6 ++
 python/pyspark/context.py                | 177 ++++++++++++++++++++++++++++++-
 python/pyspark/tests/test_pin_thread.py  |  35 ++++--
 3 files changed, 207 insertions(+), 11 deletions(-)

diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst
index ec3df071639..9a6fbb65171 100644
--- a/python/docs/source/reference/pyspark.rst
+++ b/python/docs/source/reference/pyspark.rst
@@ -55,6 +55,7 @@ Spark Context APIs
     SparkContext.accumulator
     SparkContext.addArchive
     SparkContext.addFile
+    SparkContext.addJobTag
     SparkContext.addPyFile
     SparkContext.applicationId
     SparkContext.binaryFiles
@@ -62,12 +63,15 @@ Spark Context APIs
     SparkContext.broadcast
     SparkContext.cancelAllJobs
     SparkContext.cancelJobGroup
+    SparkContext.cancelJobsWithTag
+    SparkContext.clearJobTags
     SparkContext.defaultMinPartitions
     SparkContext.defaultParallelism
     SparkContext.dump_profiles
     SparkContext.emptyRDD
     SparkContext.getCheckpointDir
     SparkContext.getConf
+    SparkContext.getJobTags
     SparkContext.getLocalProperty
     SparkContext.getOrCreate
     SparkContext.hadoopFile
@@ -80,9 +84,11 @@ Spark Context APIs
     SparkContext.pickleFile
     SparkContext.range
     SparkContext.resources
+    SparkContext.removeJobTag
     SparkContext.runJob
     SparkContext.sequenceFile
     SparkContext.setCheckpointDir
+    SparkContext.setInterruptOnCancel
     SparkContext.setJobDescription
     SparkContext.setJobGroup
     SparkContext.setLocalProperty
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 51a4db67e8c..4867ce2ae29 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -40,6 +40,7 @@ from typing import (
     Type,
     TYPE_CHECKING,
     TypeVar,
+    Set,
 )
 
 from py4j.java_collections import JavaMap
@@ -2164,6 +2165,160 @@ class SparkContext:
         """
         self._jsc.setJobGroup(groupId, description, interruptOnCancel)
 
+    def setInterruptOnCancel(self, interruptOnCancel: bool) -> None:
+        """
+        Set the behavior of job cancellation from jobs started in this thread.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        interruptOnCancel : bool
+            If true, then job cancellation will result in ``Thread.interrupt()``
+            being called on the job's executor threads. This is useful to help ensure that
+            the tasks are actually stopped in a timely manner, but is off by default due to
+            HDFS-1208, where HDFS may respond to ``Thread.interrupt()`` by marking nodes as dead.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.cancelAllJobs`
+        :meth:`SparkContext.cancelJobGroup`
+        :meth:`SparkContext.cancelJobsWithTag`
+        """
+        self._jsc.setInterruptOnCancel(interruptOnCancel)
+
+    def addJobTag(self, tag: str) -> None:
+        """
+        Add a tag to be assigned to all the jobs started by this thread.
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be added. Cannot contain ',' (comma) character.
+
+        See Also
+        --------
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> import threading
+        >>> from time import sleep
+        >>> from pyspark import InheritableThread
+        >>> sc.setInterruptOnCancel(interruptOnCancel=True)
+        >>> result = "Not Set"
+        >>> lock = threading.Lock()
+        >>> def map_func(x):
+        ...     sleep(100)
+        ...     raise RuntimeError("Task should have been cancelled")
+        ...
+        >>> def start_job(x):
+        ...     global result
+        ...     try:
+        ...         sc.addJobTag("job_to_cancel")
+        ...         result = sc.parallelize(range(x)).map(map_func).collect()
+        ...     except Exception as e:
+        ...         result = "Cancelled"
+        ...     lock.release()
+        ...
+        >>> def stop_job():
+        ...     sleep(5)
+        ...     sc.cancelJobsWithTag("job_to_cancel")
+        ...
+        >>> suppress = lock.acquire()
+        >>> suppress = InheritableThread(target=start_job, args=(10,)).start()
+        >>> suppress = InheritableThread(target=stop_job).start()
+        >>> suppress = lock.acquire()
+        >>> print(result)
+        Cancelled
+        >>> sc.clearJobTags()
+        """
+        self._jsc.addJobTag(tag)
+
+    def removeJobTag(self, tag: str) -> None:
+        """
+        Remove a tag previously added to be assigned to all the jobs started by this thread.
+        Noop if such a tag was not added earlier.
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be removed. Cannot contain ',' (comma) character.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> sc.addJobTag("job_to_cancel1")
+        >>> sc.addJobTag("job_to_cancel2")
+        >>> sc.getJobTags()
+        {'job_to_cancel1', 'job_to_cancel2'}
+        >>> sc.removeJobTag("job_to_cancel1")
+        >>> sc.getJobTags()
+        {'job_to_cancel2'}
+        >>> sc.clearJobTags()
+        """
+        self._jsc.removeJobTag(tag)
+
+    def getJobTags(self) -> Set[str]:
+        """
+        Get the tags that are currently set to be assigned to all the jobs started by this thread.
+
+        Returns
+        -------
+        set of str
+            the tags that are currently set to be assigned to all the jobs started by this thread.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> sc.addJobTag("job_to_cancel")
+        >>> sc.getJobTags()
+        {'job_to_cancel'}
+        >>> sc.clearJobTags()
+        """
+        return self._jsc.getJobTags()
+
+    def clearJobTags(self) -> None:
+        """
+        Clear the current thread's job tags.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> sc.addJobTag("job_to_cancel")
+        >>> sc.clearJobTags()
+        >>> sc.getJobTags()
+        set()
+        """
+        self._jsc.clearJobTags()
+
     def setLocalProperty(self, key: str, value: str) -> None:
         """
         Set a local property that affects jobs submitted from this thread, such as the
@@ -2243,10 +2398,29 @@ class SparkContext:
         See Also
         --------
         :meth:`SparkContext.setJobGroup`
-        :meth:`SparkContext.cancelJobGroup`
         """
         self._jsc.sc().cancelJobGroup(groupId)
 
+    def cancelJobsWithTag(self, tag: str) -> None:
+        """
+        Cancel active jobs that have the specified tag. See
+        :meth:`SparkContext.addJobTag`.
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be cancelled. Cannot contain ',' (comma) character.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.setInterruptOnCancel`
+        """
+        return self._jsc.cancelJobsWithTag(tag)
+
     def cancelAllJobs(self) -> None:
         """
         Cancel all jobs that have been scheduled or are running.
@@ -2256,6 +2430,7 @@ class SparkContext:
         See Also
         --------
         :meth:`SparkContext.cancelJobGroup`
+        :meth:`SparkContext.cancelJobsWithTag`
         :meth:`SparkContext.runJob`
         """
         self._jsc.sc().cancelAllJobs()
diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py
index dd291b8a0cc..975b5498089 100644
--- a/python/pyspark/tests/test_pin_thread.py
+++ b/python/pyspark/tests/test_pin_thread.py
@@ -83,10 +83,25 @@ class PinThreadTests(unittest.TestCase):
         assert len(set(jvm_thread_ids)) == 10
 
     def test_multiple_group_jobs(self):
-        # SPARK-22340 Add a mode to pin Python thread into JVM's
-
-        group_a = "job_ids_to_cancel"
-        group_b = "job_ids_to_run"
+        # SPARK-22340: Add a mode to pin Python thread into JVM's
+        self.check_job_cancellation(
+            lambda job_group: self.sc.setJobGroup(
+                job_group, "test rdd collect with setting job group"
+            ),
+            lambda job_group: self.sc.cancelJobGroup(job_group),
+        )
+
+    def test_multiple_group_tags(self):
+        # SPARK-44194: Test pinned thread mode with job tags.
+        self.check_job_cancellation(
+            lambda job_tag: self.sc.addJobTag(job_tag),
+            lambda job_tag: self.sc.cancelJobsWithTag(job_tag),
+        )
+
+    def check_job_cancellation(self, setter, canceller):
+
+        job_id_a = "job_ids_to_cancel"
+        job_id_b = "job_ids_to_run"
 
         threads = []
         thread_ids = range(4)
@@ -97,13 +112,13 @@ class PinThreadTests(unittest.TestCase):
         # The index of the array is the thread index which job run in.
         is_job_cancelled = [False for _ in thread_ids]
 
-        def run_job(job_group, index):
+        def run_job(job_id, index):
             """
             Executes a job with the group ``job_group``. Each job waits for 3 seconds
             and then exits.
             """
             try:
-                self.sc.setJobGroup(job_group, "test rdd collect with setting job group")
+                setter(job_id)
+                self.sc.parallelize([15]).map(lambda x: time.sleep(x)).collect()
                 is_job_cancelled[index] = False
             except Exception:
@@ -111,24 +126,24 @@ class PinThreadTests(unittest.TestCase):
                 is_job_cancelled[index] = True
 
         # Test if job succeeded when not cancelled.
-        run_job(group_a, 0)
+        run_job(job_id_a, 0)
         self.assertFalse(is_job_cancelled[0])
 
         # Run jobs
         for i in thread_ids_to_cancel:
-            t = threading.Thread(target=run_job, args=(group_a, i))
+            t = threading.Thread(target=run_job, args=(job_id_a, i))
             t.start()
             threads.append(t)
 
         for i in thread_ids_to_run:
-            t = threading.Thread(target=run_job, args=(group_b, i))
+            t = threading.Thread(target=run_job, args=(job_id_b, i))
             t.start()
             threads.append(t)
 
         # Wait to make sure all jobs are executed.
         time.sleep(3)
         # And then, cancel one job group.
-        self.sc.cancelJobGroup(group_a)
+        canceller(job_id_a)
 
         # Wait until all threads launching jobs are finished.
         for t in threads:

