This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4fcecfe17f2  [SPARK-44194][PYTHON][CORE] Add JobTag APIs to PySpark SparkContext
4fcecfe17f2 is described below

commit 4fcecfe17f2d54e14ac204bbdd97104828bbf2af
Author:     Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Jul 4 12:47:25 2023 +0900

    [SPARK-44194][PYTHON][CORE] Add JobTag APIs to PySpark SparkContext

    ### What changes were proposed in this pull request?

    This PR proposes to add:
    - `SparkContext.setInterruptOnCancel(interruptOnCancel: Boolean): Unit`
    - `SparkContext.addJobTag(tag: String): Unit`
    - `SparkContext.removeJobTag(tag: String): Unit`
    - `SparkContext.getJobTags(): Set[String]`
    - `SparkContext.clearJobTags(): Unit`
    - `SparkContext.cancelJobsWithTag(tag: String): Unit`

    into PySpark. See also SPARK-43952.

    ### Why are the changes needed?

    For PySpark users, and for feature parity.

    ### Does this PR introduce _any_ user-facing change?

    Yes, it adds new APIs in PySpark.

    ### How was this patch tested?

    Unit tests were added.

    Closes #41841 from HyukjinKwon/SPARK-44194.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/docs/source/reference/pyspark.rst |   6 ++
 python/pyspark/context.py                | 177 ++++++++++++++++++++++++++++++-
 python/pyspark/tests/test_pin_thread.py  |  35 ++++--
 3 files changed, 207 insertions(+), 11 deletions(-)

diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst
index ec3df071639..9a6fbb65171 100644
--- a/python/docs/source/reference/pyspark.rst
+++ b/python/docs/source/reference/pyspark.rst
@@ -55,6 +55,7 @@ Spark Context APIs
     SparkContext.accumulator
     SparkContext.addArchive
     SparkContext.addFile
+    SparkContext.addJobTag
     SparkContext.addPyFile
     SparkContext.applicationId
     SparkContext.binaryFiles
@@ -62,12 +63,15 @@ Spark Context APIs
     SparkContext.broadcast
     SparkContext.cancelAllJobs
     SparkContext.cancelJobGroup
+    SparkContext.cancelJobsWithTag
+    SparkContext.clearJobTags
     SparkContext.defaultMinPartitions
     SparkContext.defaultParallelism
     SparkContext.dump_profiles
     SparkContext.emptyRDD
     SparkContext.getCheckpointDir
     SparkContext.getConf
+    SparkContext.getJobTags
     SparkContext.getLocalProperty
     SparkContext.getOrCreate
     SparkContext.hadoopFile
@@ -80,9 +84,11 @@ Spark Context APIs
     SparkContext.pickleFile
     SparkContext.range
     SparkContext.resources
+    SparkContext.removeJobTag
     SparkContext.runJob
     SparkContext.sequenceFile
     SparkContext.setCheckpointDir
+    SparkContext.setInterruptOnCancel
     SparkContext.setJobDescription
     SparkContext.setJobGroup
     SparkContext.setLocalProperty
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 51a4db67e8c..4867ce2ae29 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -40,6 +40,7 @@ from typing import (
     Type,
     TYPE_CHECKING,
     TypeVar,
+    Set,
 )
 
 from py4j.java_collections import JavaMap
@@ -2164,6 +2165,160 @@ class SparkContext:
         """
         self._jsc.setJobGroup(groupId, description, interruptOnCancel)
 
+    def setInterruptOnCancel(self, interruptOnCancel: bool) -> None:
+        """
+        Set the behavior of job cancellation from jobs started in this thread.
+
+        .. versionadded:: 3.5.0
+
+        Parameters
+        ----------
+        interruptOnCancel : bool
+            If true, then job cancellation will result in ``Thread.interrupt()``
+            being called on the job's executor threads. This is useful to help ensure that
+            the tasks are actually stopped in a timely manner, but is off by default due to
+            HDFS-1208, where HDFS may respond to ``Thread.interrupt()`` by marking nodes as dead.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.cancelAllJobs`
+        :meth:`SparkContext.cancelJobGroup`
+        :meth:`SparkContext.cancelJobsWithTag`
+        """
+        self._jsc.setInterruptOnCancel(interruptOnCancel)
+
+    def addJobTag(self, tag: str) -> None:
+        """
+        Add a tag to be assigned to all the jobs started by this thread.
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be added. Cannot contain ',' (comma) character.
+
+        See Also
+        --------
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> import threading
+        >>> from time import sleep
+        >>> from pyspark import InheritableThread
+        >>> sc.setInterruptOnCancel(interruptOnCancel=True)
+        >>> result = "Not Set"
+        >>> lock = threading.Lock()
+        >>> def map_func(x):
+        ...     sleep(100)
+        ...     raise RuntimeError("Task should have been cancelled")
+        ...
+        >>> def start_job(x):
+        ...     global result
+        ...     try:
+        ...         sc.addJobTag("job_to_cancel")
+        ...         result = sc.parallelize(range(x)).map(map_func).collect()
+        ...     except Exception as e:
+        ...         result = "Cancelled"
+        ...     lock.release()
+        ...
+        >>> def stop_job():
+        ...     sleep(5)
+        ...     sc.cancelJobsWithTag("job_to_cancel")
+        ...
+        >>> suppress = lock.acquire()
+        >>> suppress = InheritableThread(target=start_job, args=(10,)).start()
+        >>> suppress = InheritableThread(target=stop_job).start()
+        >>> suppress = lock.acquire()
+        >>> print(result)
+        Cancelled
+        >>> sc.clearJobTags()
+        """
+        self._jsc.addJobTag(tag)
+
+    def removeJobTag(self, tag: str) -> None:
+        """
+        Remove a tag previously added to be assigned to all the jobs started by this thread.
+        Noop if such a tag was not added earlier.
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be removed. Cannot contain ',' (comma) character.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> sc.addJobTag("job_to_cancel1")
+        >>> sc.addJobTag("job_to_cancel2")
+        >>> sc.getJobTags()
+        {'job_to_cancel1', 'job_to_cancel2'}
+        >>> sc.removeJobTag("job_to_cancel1")
+        >>> sc.getJobTags()
+        {'job_to_cancel2'}
+        >>> sc.clearJobTags()
+        """
+        self._jsc.removeJobTag(tag)
+
+    def getJobTags(self) -> Set[str]:
+        """
+        Get the tags that are currently set to be assigned to all the jobs started by this thread.
+
+        Returns
+        -------
+        set of str
+            the tags that are currently set to be assigned to all the jobs started by this thread.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> sc.addJobTag("job_to_cancel")
+        >>> sc.getJobTags()
+        {'job_to_cancel'}
+        >>> sc.clearJobTags()
+        """
+        return self._jsc.getJobTags()
+
+    def clearJobTags(self) -> None:
+        """
+        Clear the current thread's job tags.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.cancelJobsWithTag`
+        :meth:`SparkContext.setInterruptOnCancel`
+
+        Examples
+        --------
+        >>> sc.addJobTag("job_to_cancel")
+        >>> sc.clearJobTags()
+        >>> sc.getJobTags()
+        set()
+        """
+        self._jsc.clearJobTags()
+
     def setLocalProperty(self, key: str, value: str) -> None:
         """
         Set a local property that affects jobs submitted from this thread, such as the
@@ -2243,10 +2398,29 @@ class SparkContext:
         See Also
         --------
         :meth:`SparkContext.setJobGroup`
-        :meth:`SparkContext.cancelJobGroup`
         """
         self._jsc.sc().cancelJobGroup(groupId)
 
+    def cancelJobsWithTag(self, tag: str) -> None:
+        """
+        Cancel active jobs that have the specified tag. See
+        :meth:`SparkContext.addJobTag`.
+
+        Parameters
+        ----------
+        tag : str
+            The tag to be cancelled. Cannot contain ',' (comma) character.
+
+        See Also
+        --------
+        :meth:`SparkContext.addJobTag`
+        :meth:`SparkContext.removeJobTag`
+        :meth:`SparkContext.getJobTags`
+        :meth:`SparkContext.clearJobTags`
+        :meth:`SparkContext.setInterruptOnCancel`
+        """
+        return self._jsc.cancelJobsWithTag(tag)
+
     def cancelAllJobs(self) -> None:
         """
         Cancel all jobs that have been scheduled or are running.
@@ -2256,6 +2430,7 @@ class SparkContext:
         See Also
         --------
         :meth:`SparkContext.cancelJobGroup`
+        :meth:`SparkContext.cancelJobsWithTag`
         :meth:`SparkContext.runJob`
         """
         self._jsc.sc().cancelAllJobs()
diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py
index dd291b8a0cc..975b5498089 100644
--- a/python/pyspark/tests/test_pin_thread.py
+++ b/python/pyspark/tests/test_pin_thread.py
@@ -83,10 +83,25 @@ class PinThreadTests(unittest.TestCase):
         assert len(set(jvm_thread_ids)) == 10
 
     def test_multiple_group_jobs(self):
-        # SPARK-22340 Add a mode to pin Python thread into JVM's
-
-        group_a = "job_ids_to_cancel"
-        group_b = "job_ids_to_run"
+        # SPARK-22340: Add a mode to pin Python thread into JVM's
+        self.check_job_cancellation(
+            lambda job_group: self.sc.setJobGroup(
+                job_group, "test rdd collect with setting job group"
+            ),
+            lambda job_group: self.sc.cancelJobGroup(job_group),
+        )
+
+    def test_multiple_group_tags(self):
+        # SPARK-44194: Test pinned thread mode with job tags.
+        self.check_job_cancellation(
+            lambda job_tag: self.sc.addJobTag(job_tag),
+            lambda job_tag: self.sc.cancelJobsWithTag(job_tag),
+        )
+
+    def check_job_cancellation(self, setter, canceller):
+
+        job_id_a = "job_ids_to_cancel"
+        job_id_b = "job_ids_to_run"
 
         threads = []
         thread_ids = range(4)
@@ -97,13 +112,13 @@ class PinThreadTests(unittest.TestCase):
         # The index of the array is the thread index which job run in.
         is_job_cancelled = [False for _ in thread_ids]
 
-        def run_job(job_group, index):
+        def run_job(job_id, index):
             """
             Executes a job with the group ``job_group``. Each job waits for 3 seconds
             and then exits.
             """
             try:
-                self.sc.setJobGroup(job_group, "test rdd collect with setting job group")
+                setter(job_id)
                 self.sc.parallelize([15]).map(lambda x: time.sleep(x)).collect()
                 is_job_cancelled[index] = False
             except Exception:
@@ -111,24 +126,24 @@ class PinThreadTests(unittest.TestCase):
                 is_job_cancelled[index] = True
 
         # Test if job succeeded when not cancelled.
-        run_job(group_a, 0)
+        run_job(job_id_a, 0)
         self.assertFalse(is_job_cancelled[0])
 
         # Run jobs
         for i in thread_ids_to_cancel:
-            t = threading.Thread(target=run_job, args=(group_a, i))
+            t = threading.Thread(target=run_job, args=(job_id_a, i))
             t.start()
             threads.append(t)
 
         for i in thread_ids_to_run:
-            t = threading.Thread(target=run_job, args=(group_b, i))
+            t = threading.Thread(target=run_job, args=(job_id_b, i))
             t.start()
             threads.append(t)
 
         # Wait to make sure all jobs are executed.
         time.sleep(3)
 
         # And then, cancel one job group.
-        self.sc.cancelJobGroup(group_a)
+        canceller(job_id_a)
 
         # Wait until all threads launching jobs are finished.
         for t in threads:

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
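
For reference, a minimal usage sketch of the tag-based cancellation APIs introduced by this commit. It assumes Spark 3.5.0 or later with the default pinned thread mode and a local SparkContext; the tag name "long-running", the sleep durations, and the long_job helper are illustrative only, not part of the patch.

    # Usage sketch (illustrative): tag jobs from one thread, cancel them by tag
    # from another thread.
    import time

    from pyspark import InheritableThread, SparkContext

    sc = SparkContext.getOrCreate()

    def long_job():
        try:
            sc.setInterruptOnCancel(True)    # interrupt executor threads on cancellation
            sc.addJobTag("long-running")     # tag all jobs started by this thread
            sc.parallelize(range(4)).map(lambda x: time.sleep(60) or x).collect()
        except Exception:
            print("job cancelled")           # cancellation surfaces as an exception
        finally:
            sc.clearJobTags()                # drop this thread's tags

    # InheritableThread keeps job tags and other thread-local properties
    # consistent in pinned thread mode (the default in recent releases).
    worker = InheritableThread(target=long_job)
    worker.start()

    time.sleep(5)                            # let the tagged job get scheduled
    sc.cancelJobsWithTag("long-running")     # cancel only the tagged jobs
    worker.join()

Unlike a job group, of which a thread has exactly one, multiple tags can be attached to the jobs of a thread, so overlapping sets of jobs can be cancelled independently with cancelJobsWithTag.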