Repository: spark Updated Branches: refs/heads/branch-2.0 678d91c1d -> 2aae220b5
[SPARK-18928][BRANCH-2.0] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter This is a branch-2.0 backport of #16340; the original description follows: ## What changes were proposed in this pull request? In order to respond to task cancellation, Spark tasks must periodically check `TaskContext.isInterrupted()`, but this check is missing from a few critical read paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and UnsafeSorter-based sorts. This can cause interrupted / cancelled tasks to continue running and become zombies (as also described in #16189). This patch aims to fix this problem by adding `TaskContext.isInterrupted()` checks to these paths. Note that I could have used `InterruptibleIterator` to simply wrap a bunch of iterators, but in some cases this would incur a performance penalty or might not be effective due to certain special uses of Iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic into existing iterator subclasses. ## How was this patch tested? Tested manually in `spark-shell` with two different reproductions of non-cancellable tasks, one involving scans of huge files and another involving sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by the changes added here. Author: Josh Rosen <[email protected]> Closes #16357 from JoshRosen/sql-task-interruption-branch-2.0. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2aae220b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2aae220b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2aae220b Branch: refs/heads/branch-2.0 Commit: 2aae220b536065f55b2cf644a2a223aab0d051d0 Parents: 678d91c Author: Josh Rosen <[email protected]> Authored: Tue Dec 20 16:05:04 2016 -0800 Committer: Yin Huai <[email protected]> Committed: Tue Dec 20 16:05:04 2016 -0800 ---------------------------------------------------------------------- .../collection/unsafe/sort/UnsafeInMemorySorter.java | 11 +++++++++++ .../collection/unsafe/sort/UnsafeSorterSpillReader.java | 11 +++++++++++ .../spark/sql/execution/datasources/FileScanRDD.scala | 12 ++++++++++-- .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 9 ++++++++- 4 files changed, 40 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index b517371..2bd756f 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -21,6 +21,8 @@ import java.util.Comparator; import org.apache.avro.reflect.Nullable; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskKilledException; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -226,6 +228,7 @@ public final class UnsafeInMemorySorter 
{ private long keyPrefix; private int recordLength; private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; @@ -256,6 +259,14 @@ public final class UnsafeInMemorySorter { @Override public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null && taskContext.isInterrupted()) { + throw new TaskKilledException(); + } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 1d588c3..a3f04de 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -22,6 +22,8 @@ import java.io.*; import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskKilledException; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import 
org.apache.spark.unsafe.Platform; @@ -44,6 +46,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; + private final TaskContext taskContext = TaskContext.get(); public UnsafeSorterSpillReader( SerializerManager serializerManager, @@ -73,6 +76,14 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen @Override public void loadNext() throws IOException { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null && taskContext.isInterrupted()) { + throw new TaskKilledException(); + } recordLength = din.readInt(); keyPrefix = din.readLong(); if (recordLength > arr.length) { http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 1314c94..b4deaf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable -import org.apache.spark.{Partition => RDDPartition, TaskContext} +import org.apache.spark.{Partition => RDDPartition, TaskContext, 
TaskKilledException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SparkSession @@ -88,7 +88,15 @@ class FileScanRDD( private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null - def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator() + def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. + if (context.isInterrupted()) { + throw new TaskKilledException + } + (currentIterator != null && currentIterator.hasNext) || nextIterator() + } def next() = { val nextElement = currentIterator.next() // TODO: we should have a better separation of row based and batch based scan, so that we http://git-wip-us.apache.org/repos/asf/spark/blob/2aae220b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 95058cc..a0afe34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils -import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.{Partition, SparkContext, TaskContext, TaskKilledException} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -541,6 +541,13 @@ private[jdbc] class JDBCRDD( } override def hasNext: Boolean = { + 
// Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead and to minimize modified code since it's not easy to + // wrap this Iterator without re-indenting tons of code. + if (context.isInterrupted()) { + throw new TaskKilledException + } if (!finished) { if (!gotNext) { nextValue = getNext() --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
