Repository: spark
Updated Branches:
  refs/heads/master 2d34183b2 -> 5b21139db


[SPARK-10530][CORE] Kill other task attempts when one task attempt belonging to the 
same task succeeds in speculation

## What changes were proposed in this pull request?

With this patch, TaskSetManager kills the other running attempts of a task when 
any one of its attempts succeeds. Killed tasks are also no longer counted as 
failed tasks: they are listed separately in the UI, which shows their task state 
as KILLED instead of FAILED.

## How was this patch tested?

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

I have also verified this patch manually by setting spark.speculation to true. 
When any attempt succeeds, the other running attempts of the same task are 
killed, and other pending tasks are assigned to the freed slots. Killed attempts 
are reported as KILLED tasks rather than FAILED tasks. Please see the attached 
screenshots for reference.

![stage-tasks-table](https://cloud.githubusercontent.com/assets/3174804/14075132/394c6a12-f4f4-11e5-8638-20ff7b8cc9bc.png)
![stages-table](https://cloud.githubusercontent.com/assets/3174804/14075134/3b60f412-f4f4-11e5-9ea6-dd0dcc86eb03.png)

Ref : https://github.com/apache/spark/pull/11916

Author: Devaraj K <[email protected]>

Closes #11996 from devaraj-kavali/SPARK-10530.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b21139d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b21139d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b21139d

Branch: refs/heads/master
Commit: 5b21139dbf3bd09cb3a590bd0ffb857ea92dc23c
Parents: 2d34183
Author: Devaraj K <[email protected]>
Authored: Mon May 30 14:29:27 2016 -0700
Committer: Kay Ousterhout <[email protected]>
Committed: Mon May 30 14:29:27 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/scheduler/Pool.scala |  4 +-
 .../apache/spark/scheduler/Schedulable.scala    |  2 +-
 .../org/apache/spark/scheduler/TaskInfo.scala   | 20 +++++---
 .../spark/scheduler/TaskSchedulerImpl.scala     |  7 ++-
 .../apache/spark/scheduler/TaskSetManager.scala | 18 +++++--
 .../scala/org/apache/spark/ui/UIUtils.scala     |  2 +
 .../org/apache/spark/ui/jobs/AllJobsPage.scala  |  2 +-
 .../apache/spark/ui/jobs/ExecutorTable.scala    |  4 +-
 .../spark/ui/jobs/JobProgressListener.scala     |  7 +++
 .../org/apache/spark/ui/jobs/StageTable.scala   |  2 +-
 .../scala/org/apache/spark/ui/jobs/UIData.scala |  3 ++
 .../org/apache/spark/util/JsonProtocol.scala    |  3 ++
 .../spark/scheduler/TaskSetManagerSuite.scala   | 50 ++++++++++++++++++++
 .../org/apache/spark/ui/StagePageSuite.scala    |  2 +-
 .../org/apache/spark/ui/UIUtilsSuite.scala      |  2 +-
 .../ui/jobs/JobProgressListenerSuite.scala      |  6 ++-
 .../apache/spark/util/JsonProtocolSuite.scala   |  5 ++
 .../spark/streaming/ui/AllBatchesTable.scala    |  1 +
 .../apache/spark/streaming/ui/BatchPage.scala   |  1 +
 19 files changed, 119 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 5987cfe..732c89c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -88,10 +88,10 @@ private[spark] class Pool(
     schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason))
   }
 
-  override def checkSpeculatableTasks(): Boolean = {
+  override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
     var shouldRevive = false
     for (schedulable <- schedulableQueue.asScala) {
-      shouldRevive |= schedulable.checkSpeculatableTasks()
+      shouldRevive |= schedulable.checkSpeculatableTasks(minTimeToSpeculation)
     }
     shouldRevive
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
index ab00bc8..b6f88ed 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala
@@ -43,6 +43,6 @@ private[spark] trait Schedulable {
   def removeSchedulable(schedulable: Schedulable): Unit
   def getSchedulableByName(name: String): Schedulable
   def executorLost(executorId: String, host: String, reason: 
ExecutorLossReason): Unit
-  def checkSpeculatableTasks(): Boolean
+  def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean
   def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index a42990a..2d89232 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.ListBuffer
 
+import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
 import org.apache.spark.annotation.DeveloperApi
 
 /**
@@ -58,24 +60,26 @@ class TaskInfo(
 
   var failed = false
 
+  var killed = false
+
   private[spark] def markGettingResult(time: Long = System.currentTimeMillis) {
     gettingResultTime = time
   }
 
-  private[spark] def markSuccessful(time: Long = System.currentTimeMillis) {
+  private[spark] def markFinished(state: TaskState, time: Long = 
System.currentTimeMillis) {
     finishTime = time
-  }
-
-  private[spark] def markFailed(time: Long = System.currentTimeMillis) {
-    finishTime = time
-    failed = true
+    if (state == TaskState.FAILED) {
+      failed = true
+    } else if (state == TaskState.KILLED) {
+      killed = true
+    }
   }
 
   def gettingResult: Boolean = gettingResultTime != 0
 
   def finished: Boolean = finishTime != 0
 
-  def successful: Boolean = finished && !failed
+  def successful: Boolean = finished && !failed && !killed
 
   def running: Boolean = !finished
 
@@ -88,6 +92,8 @@ class TaskInfo(
       }
     } else if (failed) {
       "FAILED"
+    } else if (killed) {
+      "KILLED"
     } else if (successful) {
       "SUCCESS"
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 01e85ca..5cb1af9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -65,6 +65,11 @@ private[spark] class TaskSchedulerImpl(
   // How often to check for speculative tasks
   val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", 
"100ms")
 
+  // Duplicate copies of a task will only be launched if the original copy has 
been running for
+  // at least this amount of time. This is to avoid the overhead of launching 
speculative copies
+  // of tasks that are very short.
+  val MIN_TIME_TO_SPECULATION = 100
+
   private val speculationScheduler =
     
ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation")
 
@@ -463,7 +468,7 @@ private[spark] class TaskSchedulerImpl(
   def checkSpeculatableTasks() {
     var shouldRevive = false
     synchronized {
-      shouldRevive = rootPool.checkSpeculatableTasks()
+      shouldRevive = rootPool.checkSpeculatableTasks(MIN_TIME_TO_SPECULATION)
     }
     if (shouldRevive) {
       backend.reviveOffers()

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 08d33f6..2eedd20 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -608,7 +608,7 @@ private[spark] class TaskSetManager(
   def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = {
     val info = taskInfos(tid)
     val index = info.index
-    info.markSuccessful()
+    info.markFinished(TaskState.FINISHED)
     removeRunningTask(tid)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which 
holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, 
we should not
@@ -617,6 +617,14 @@ private[spark] class TaskSetManager(
     // Note: "result.value()" only deserializes the value when it's called at 
the first time, so
     // here "result.value()" just returns the value and won't block other 
threads.
     sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), 
result.accumUpdates, info)
+    // Kill any other attempts for the same task (since those are unnecessary 
now that one
+    // attempt completed successfully).
+    for (attemptInfo <- taskAttempts(index) if attemptInfo.running) {
+      logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task 
${attemptInfo.id} " +
+        s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on 
${attemptInfo.host} " +
+        s"as the attempt ${info.attemptNumber} succeeded on ${info.host}")
+      sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true)
+    }
     if (!successful(index)) {
       tasksSuccessful += 1
       logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s 
(%d/%d)".format(
@@ -640,11 +648,11 @@ private[spark] class TaskSetManager(
    */
   def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) {
     val info = taskInfos(tid)
-    if (info.failed) {
+    if (info.failed || info.killed) {
       return
     }
     removeRunningTask(tid)
-    info.markFailed()
+    info.markFinished(state)
     val index = info.index
     copiesRunning(index) -= 1
     var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty
@@ -821,7 +829,7 @@ private[spark] class TaskSetManager(
    * TODO: To make this scale to large jobs, we need to maintain a list of 
running tasks, so that
    * we don't scan the whole task set. It might also help to make this sorted 
by launch time.
    */
-  override def checkSpeculatableTasks(): Boolean = {
+  override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = {
     // Can't speculate if we only have one task, and no need to speculate if 
the task set is a
     // zombie.
     if (isZombie || numTasks == 1) {
@@ -835,7 +843,7 @@ private[spark] class TaskSetManager(
       val durations = 
taskInfos.values.filter(_.successful).map(_.duration).toArray
       Arrays.sort(durations)
       val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, 
durations.length - 1))
-      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+      val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 
minTimeToSpeculation)
       // TODO: Threshold should also look at standard deviation of task 
durations and have a lower
       // bound based on that.
       logDebug("Task length threshold for speculation: " + threshold)

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala 
b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 1aa85d6..4e2fe5e 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -337,6 +337,7 @@ private[spark] object UIUtils extends Logging {
       completed: Int,
       failed: Int,
       skipped: Int,
+      killed: Int,
       total: Int): Seq[Node] = {
     val completeWidth = "width: %s%%".format((completed.toDouble/total)*100)
     // started + completed can be > total when there are speculative tasks
@@ -348,6 +349,7 @@ private[spark] object UIUtils extends Logging {
         {completed}/{total}
         { if (failed > 0) s"($failed failed)" }
         { if (skipped > 0) s"($skipped skipped)" }
+        { if (killed > 0) s"($killed killed)" }
       </span>
       <div class="bar bar-completed" style={completeWidth}></div>
       <div class="bar bar-running" style={startWidth}></div>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 373c26b..035d706 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -256,7 +256,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends 
WebUIPage("") {
         </td>
         <td class="progress-cell">
           {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = 
job.numCompletedTasks,
-           failed = job.numFailedTasks, skipped = job.numSkippedTasks,
+           failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed 
= job.numKilledTasks,
            total = job.numTasks - job.numSkippedTasks)}
         </td>
       </tr>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index f609fb4..293f143 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -57,6 +57,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: 
Int, parent: Stage
         <th>Task Time</th>
         <th>Total Tasks</th>
         <th>Failed Tasks</th>
+        <th>Killed Tasks</th>
         <th>Succeeded Tasks</th>
         {if (hasInput) {
           <th>
@@ -116,8 +117,9 @@ private[ui] class ExecutorTable(stageId: Int, 
stageAttemptId: Int, parent: Stage
             <td>{k}</td>
             <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
             <td 
sorttable_customkey={v.taskTime.toString}>{UIUtils.formatDuration(v.taskTime)}</td>
-            <td>{v.failedTasks + v.succeededTasks}</td>
+            <td>{v.failedTasks + v.succeededTasks + v.killedTasks}</td>
             <td>{v.failedTasks}</td>
+            <td>{v.killedTasks}</td>
             <td>{v.succeededTasks}</td>
             {if (stageData.hasInput) {
               <td sorttable_customkey={v.inputBytes.toString}>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 842f42b..c882740 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -369,6 +369,8 @@ class JobProgressListener(conf: SparkConf) extends 
SparkListener with Logging {
       taskEnd.reason match {
         case Success =>
           execSummary.succeededTasks += 1
+        case TaskKilled =>
+          execSummary.killedTasks += 1
         case _ =>
           execSummary.failedTasks += 1
       }
@@ -381,6 +383,9 @@ class JobProgressListener(conf: SparkConf) extends 
SparkListener with Logging {
             stageData.completedIndices.add(info.index)
             stageData.numCompleteTasks += 1
             None
+          case TaskKilled =>
+            stageData.numKilledTasks += 1
+            Some(TaskKilled.toErrorString)
           case e: ExceptionFailure => // Handle ExceptionFailure because we 
might have accumUpdates
             stageData.numFailedTasks += 1
             Some(e.toErrorString)
@@ -409,6 +414,8 @@ class JobProgressListener(conf: SparkConf) extends 
SparkListener with Logging {
         taskEnd.reason match {
           case Success =>
             jobData.numCompletedTasks += 1
+          case TaskKilled =>
+            jobData.numKilledTasks += 1
           case _ =>
             jobData.numFailedTasks += 1
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 2a1c3c1..0e02015 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -195,7 +195,7 @@ private[ui] class StageTableBase(
     <td class="progress-cell">
       {UIUtils.makeProgressBar(started = stageData.numActiveTasks,
         completed = stageData.completedIndices.size, failed = 
stageData.numFailedTasks,
-        skipped = 0, total = s.numTasks)}
+        skipped = 0, killed = stageData.numKilledTasks, total = s.numTasks)}
     </td>
     <td sorttable_customkey={inputRead.toString}>{inputReadWithUnit}</td>
     <td sorttable_customkey={outputWrite.toString}>{outputWriteWithUnit}</td>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index d76a0e6..20dde7c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -33,6 +33,7 @@ private[spark] object UIData {
     var taskTime : Long = 0
     var failedTasks : Int = 0
     var succeededTasks : Int = 0
+    var killedTasks : Int = 0
     var inputBytes : Long = 0
     var inputRecords : Long = 0
     var outputBytes : Long = 0
@@ -63,6 +64,7 @@ private[spark] object UIData {
     var numCompletedTasks: Int = 0,
     var numSkippedTasks: Int = 0,
     var numFailedTasks: Int = 0,
+    var numKilledTasks: Int = 0,
     /* Stages */
     var numActiveStages: Int = 0,
     // This needs to be a set instead of a simple count to prevent 
double-counting of rerun stages:
@@ -76,6 +78,7 @@ private[spark] object UIData {
     var numCompleteTasks: Int = _
     var completedIndices = new OpenHashSet[Int]()
     var numFailedTasks: Int = _
+    var numKilledTasks: Int = _
 
     var executorRunTime: Long = _
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala 
b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 18547d4..022b226 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -280,6 +280,7 @@ private[spark] object JsonProtocol {
     ("Getting Result Time" -> taskInfo.gettingResultTime) ~
     ("Finish Time" -> taskInfo.finishTime) ~
     ("Failed" -> taskInfo.failed) ~
+    ("Killed" -> taskInfo.killed) ~
     ("Accumulables" -> 
JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList))
   }
 
@@ -697,6 +698,7 @@ private[spark] object JsonProtocol {
     val gettingResultTime = (json \ "Getting Result Time").extract[Long]
     val finishTime = (json \ "Finish Time").extract[Long]
     val failed = (json \ "Failed").extract[Boolean]
+    val killed = (json \ "Killed").extractOpt[Boolean].getOrElse(false)
     val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match {
       case Some(values) => values.map(accumulableInfoFromJson)
       case None => Seq[AccumulableInfo]()
@@ -707,6 +709,7 @@ private[spark] object JsonProtocol {
     taskInfo.gettingResultTime = gettingResultTime
     taskInfo.finishTime = finishTime
     taskInfo.failed = failed
+    taskInfo.killed = killed
     accumulables.foreach { taskInfo.accumulables += _ }
     taskInfo
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 9b7b945..1d7c8f4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -22,6 +22,8 @@ import java.util.Random
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.mockito.Mockito.{mock, verify}
+
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.{AccumulatorV2, ManualClock}
@@ -789,6 +791,54 @@ class TaskSetManagerSuite extends SparkFunSuite with 
LocalSparkContext with Logg
     assert(TaskLocation("executor_host1_3") === 
ExecutorCacheTaskLocation("host1", "3"))
   }
 
+  test("Kill other task attempts when one attempt belonging to the same task 
succeeds") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", 
"host2"))
+    val taskSet = FakeTask.createTaskSet(4)
+    // Set the speculation multiplier to be 0 so speculative tasks are 
launched immediately
+    sc.conf.set("spark.speculation.multiplier", "0.0")
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+    val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = 
taskSet.tasks.map { task =>
+      task.metrics.internalAccums
+    }
+    // Offer resources for 4 tasks to start
+    for ((k, v) <- List(
+        "exec1" -> "host1",
+        "exec1" -> "host1",
+        "exec2" -> "host2",
+        "exec2" -> "host2")) {
+      val taskOption = manager.resourceOffer(k, v, NO_PREF)
+      assert(taskOption.isDefined)
+      val task = taskOption.get
+      assert(task.executorId === k)
+    }
+    assert(sched.startedTasks.toSet === Set(0, 1, 2, 3))
+    // Complete the 3 tasks and leave 1 task in running
+    for (id <- Set(0, 1, 2)) {
+      manager.handleSuccessfulTask(id, createTaskResult(id, 
accumUpdatesByTask(id)))
+      assert(sched.endedTasks(id) === Success)
+    }
+
+    assert(manager.checkSpeculatableTasks(0))
+    // Offer resource to start the speculative attempt for the running task
+    val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF)
+    assert(taskOption5.isDefined)
+    val task5 = taskOption5.get
+    assert(task5.index === 3)
+    assert(task5.taskId === 4)
+    assert(task5.executorId === "exec1")
+    assert(task5.attemptNumber === 1)
+    sched.backend = mock(classOf[SchedulerBackend])
+    // Complete the speculative attempt for the running task
+    manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3)))
+    // Verify that it kills other running attempt
+    verify(sched.backend).killTask(3, "exec2", true)
+    // Because the SchedulerBackend was a mock, the 2nd copy of the task won't 
actually be
+    // killed, so the FakeTaskScheduler is only told about the successful 
completion
+    // of the speculated task.
+    assert(sched.endedTasks(3) === Success)
+  }
+
   private def createTaskResult(
       id: Int,
       accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): 
DirectTaskResult[Int] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala 
b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index b83ffa3..6d726d3 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -83,7 +83,7 @@ class StagePageSuite extends SparkFunSuite with 
LocalSparkContext {
         val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", 
TaskLocality.ANY, false)
         jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
         jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
-        taskInfo.markSuccessful()
+        taskInfo.markFinished(TaskState.FINISHED)
         val taskMetrics = TaskMetrics.empty
         taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
         jobListener.onTaskEnd(

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index 58beaf1..6335d90 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite {
   }
 
   test("SPARK-11906: Progress bar should not overflow because of speculative 
tasks") {
-    val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label 
== "div")
+    val generated = makeProgressBar(2, 3, 0, 0, 0, 
4).head.child.filter(_.label == "div")
     val expected = Seq(
       <div class="bar bar-completed" style="width: 75.0%"></div>,
       <div class="bar bar-running" style="width: 25.0%"></div>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala 
b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 1fa9b28..edab727 100644
--- 
a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -243,7 +243,6 @@ class JobProgressListenerSuite extends SparkFunSuite with 
LocalSparkContext with
       new FetchFailed(null, 0, 0, 0, "ignored"),
       ExceptionFailure("Exception", "description", null, null, None),
       TaskResultLost,
-      TaskKilled,
       ExecutorLostFailure("0", true, Some("Induced failure")),
       UnknownReason)
     var failCount = 0
@@ -255,6 +254,11 @@ class JobProgressListenerSuite extends SparkFunSuite with 
LocalSparkContext with
       assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === 
failCount)
     }
 
+    // Make sure killed tasks are accounted for correctly.
+    listener.onTaskEnd(
+      SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, 
metrics))
+    assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1)
+
     // Make sure we count success as success.
     listener.onTaskEnd(
       SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, 
metrics))

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala 
b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 6fda737..0a8bbba 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -966,6 +966,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       |    "Getting Result Time": 0,
       |    "Finish Time": 0,
       |    "Failed": false,
+      |    "Killed": false,
       |    "Accumulables": [
       |      {
       |        "ID": 1,
@@ -1012,6 +1013,7 @@ private[spark] object JsonProtocolSuite extends 
Assertions {
       |    "Getting Result Time": 0,
       |    "Finish Time": 0,
       |    "Failed": false,
+      |    "Killed": false,
       |    "Accumulables": [
       |      {
       |        "ID": 1,
@@ -1064,6 +1066,7 @@ private[spark] object JsonProtocolSuite extends 
Assertions {
       |    "Getting Result Time": 0,
       |    "Finish Time": 0,
       |    "Failed": false,
+      |    "Killed": false,
       |    "Accumulables": [
       |      {
       |        "ID": 1,
@@ -1161,6 +1164,7 @@ private[spark] object JsonProtocolSuite extends 
Assertions {
       |    "Getting Result Time": 0,
       |    "Finish Time": 0,
       |    "Failed": false,
+      |    "Killed": false,
       |    "Accumulables": [
       |      {
       |        "ID": 1,
@@ -1258,6 +1262,7 @@ private[spark] object JsonProtocolSuite extends 
Assertions {
       |    "Getting Result Time": 0,
       |    "Finish Time": 0,
       |    "Failed": false,
+      |    "Killed": false,
       |    "Accumulables": [
       |      {
       |        "ID": 1,

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
index c024b4e..1352ca1 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
@@ -97,6 +97,7 @@ private[ui] abstract class BatchTableBase(tableId: String, 
batchInterval: Long)
         completed = batch.numCompletedOutputOp,
         failed = batch.numFailedOutputOp,
         skipped = 0,
+        killed = 0,
         total = batch.outputOperations.size)
       }
     </td>

http://git-wip-us.apache.org/repos/asf/spark/blob/5b21139d/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
index 60122b4..1a87fc7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala
@@ -146,6 +146,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends 
WebUIPage("batch") {
             completed = sparkJob.numCompletedTasks,
             failed = sparkJob.numFailedTasks,
             skipped = sparkJob.numSkippedTasks,
+            killed = sparkJob.numKilledTasks,
             total = sparkJob.numTasks - sparkJob.numSkippedTasks)
         }
       </td>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to