Repository: spark
Updated Branches:
  refs/heads/branch-2.1 c1a26b458 -> f07e989c0


[SPARK-18928] Check TaskContext.isInterrupted() in FileScanRDD, JDBCRDD & UnsafeSorter

## What changes were proposed in this pull request?

In order to respond to task cancellation, Spark tasks must periodically check 
`TaskContext.isInterrupted()`, but this check is missing on a few critical read 
paths used in Spark SQL, including `FileScanRDD`, `JDBCRDD`, and 
UnsafeSorter-based sorts. This can cause interrupted / cancelled tasks to 
continue running and become zombies (as also described in #16189).

This patch fixes the problem by adding `TaskContext.isInterrupted()` checks to
these paths. I could have simply wrapped a number of iterators in
`InterruptibleIterator`, but in some cases this would incur a performance
penalty, or might not be effective because of certain special uses of
iterators in Spark SQL. Instead, I inlined `InterruptibleIterator`-style logic
into the existing iterator subclasses.

## How was this patch tested?

Tested manually in `spark-shell` with two different reproductions of 
non-cancellable tasks, one involving scans of huge files and another involving 
sort-merge joins that spill to disk. Both causes of zombie tasks are fixed by 
the changes added here.

Author: Josh Rosen <[email protected]>

Closes #16340 from JoshRosen/sql-task-interruption.

(cherry picked from commit 5857b9ac2d9808d9b89a5b29620b5052e2beebf5)
Signed-off-by: Herman van Hovell <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f07e989c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f07e989c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f07e989c

Branch: refs/heads/branch-2.1
Commit: f07e989c02844151587f9a29fe77ea65facea422
Parents: c1a26b4
Author: Josh Rosen <[email protected]>
Authored: Tue Dec 20 01:19:38 2016 +0100
Committer: Herman van Hovell <[email protected]>
Committed: Tue Dec 20 01:19:51 2016 +0100

----------------------------------------------------------------------
 .../collection/unsafe/sort/UnsafeInMemorySorter.java    | 11 +++++++++++
 .../collection/unsafe/sort/UnsafeSorterSpillReader.java | 11 +++++++++++
 .../spark/sql/execution/datasources/FileScanRDD.scala   | 12 ++++++++++--
 .../spark/sql/execution/datasources/jdbc/JDBCRDD.scala  |  5 +++--
 4 files changed, 35 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f07e989c/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 252a35e..5b42843 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -22,6 +22,8 @@ import java.util.LinkedList;
 
 import org.apache.avro.reflect.Nullable;
 
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -253,6 +255,7 @@ public final class UnsafeInMemorySorter {
     private long keyPrefix;
     private int recordLength;
     private long currentPageNumber;
+    private final TaskContext taskContext = TaskContext.get();
 
     private SortedIterator(int numRecords, int offset) {
       this.numRecords = numRecords;
@@ -283,6 +286,14 @@ public final class UnsafeInMemorySorter {
 
     @Override
     public void loadNext() {
+      // Kill the task in case it has been marked as killed. This logic is from
+      // InterruptibleIterator, but we inline it here instead of wrapping the 
iterator in order
+      // to avoid performance overhead. This check is added here in 
`loadNext()` instead of in
+      // `hasNext()` because it's technically possible for the caller to be 
relying on
+      // `getNumRecords()` instead of `hasNext()` to know when to stop.
+      if (taskContext != null && taskContext.isInterrupted()) {
+        throw new TaskKilledException();
+      }
       // This pointer points to a 4-byte record length, followed by the 
record's bytes
       final long recordPointer = array.get(offset + position);
       currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer);

http://git-wip-us.apache.org/repos/asf/spark/blob/f07e989c/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index a658e5e..b6323c6 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -23,6 +23,8 @@ import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.TaskKilledException;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.storage.BlockId;
@@ -51,6 +53,7 @@ public final class UnsafeSorterSpillReader extends 
UnsafeSorterIterator implemen
   private byte[] arr = new byte[1024 * 1024];
   private Object baseObject = arr;
   private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
+  private final TaskContext taskContext = TaskContext.get();
 
   public UnsafeSorterSpillReader(
       SerializerManager serializerManager,
@@ -94,6 +97,14 @@ public final class UnsafeSorterSpillReader extends 
UnsafeSorterIterator implemen
 
   @Override
   public void loadNext() throws IOException {
+    // Kill the task in case it has been marked as killed. This logic is from
+    // InterruptibleIterator, but we inline it here instead of wrapping the 
iterator in order
+    // to avoid performance overhead. This check is added here in `loadNext()` 
instead of in
+    // `hasNext()` because it's technically possible for the caller to be 
relying on
+    // `getNumRecords()` instead of `hasNext()` to know when to stop.
+    if (taskContext != null && taskContext.isInterrupted()) {
+      throw new TaskKilledException();
+    }
     recordLength = din.readInt();
     keyPrefix = din.readLong();
     if (recordLength > arr.length) {

http://git-wip-us.apache.org/repos/asf/spark/blob/f07e989c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 69338f7..b926b92 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -21,7 +21,7 @@ import java.io.IOException
 
 import scala.collection.mutable
 
-import org.apache.spark.{Partition => RDDPartition, TaskContext}
+import org.apache.spark.{Partition => RDDPartition, TaskContext, 
TaskKilledException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{InputFileNameHolder, RDD}
 import org.apache.spark.sql.SparkSession
@@ -99,7 +99,15 @@ class FileScanRDD(
       private[this] var currentFile: PartitionedFile = null
       private[this] var currentIterator: Iterator[Object] = null
 
-      def hasNext: Boolean = (currentIterator != null && 
currentIterator.hasNext) || nextIterator()
+      def hasNext: Boolean = {
+        // Kill the task in case it has been marked as killed. This logic is 
from
+        // InterruptibleIterator, but we inline it here instead of wrapping 
the iterator in order
+        // to avoid performance overhead.
+        if (context.isInterrupted()) {
+          throw new TaskKilledException
+        }
+        (currentIterator != null && currentIterator.hasNext) || nextIterator()
+      }
       def next(): Object = {
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch 
based scan, so that we

http://git-wip-us.apache.org/repos/asf/spark/blob/f07e989c/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index d5b11e7..2bdc432 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{Partition, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, 
TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -301,6 +301,7 @@ private[jdbc] class JDBCRDD(
     rs = stmt.executeQuery()
     val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, 
inputMetrics)
 
-    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, 
close())
+    CompletionIterator[InternalRow, Iterator[InternalRow]](
+      new InterruptibleIterator(context, rowsIterator), close())
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to