This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 1b221f3 [SPARK-31472][CORE][3.0] Make sure Barrier Task always return
messages or exception with abortableRpcFuture check
1b221f3 is described below
commit 1b221f35abd1657a3ecd49335118bfd5dcb811ee
Author: yi.wu <[email protected]>
AuthorDate: Thu Apr 23 14:43:27 2020 +0000
[SPARK-31472][CORE][3.0] Make sure Barrier Task always return messages or
exception with abortableRpcFuture check
### What changes were proposed in this pull request?
Rewrite the periodic check logic of `abortableRpcFuture` to make sure
that a barrier task always returns either the desired messages or the
expected exception.
This PR also simplifies `AbortableRpcFuture` a bit.
### Why are the changes needed?
Currently, the periodic check of `abortableRpcFuture` is done as
follows:
```scala
...
var messages: Array[String] = null
while (!abortableRpcFuture.toFuture.isCompleted) {
messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
...
}
return messages
```
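It's possible for `abortableRpcFuture` to complete after an `awaitResult`
call has timed out but before the `while` condition is re-checked. The loop
then exits without ever assigning the future's result to `messages`. In this
case, the task may return null messages, or it may finish successfully when it
should have thrown an exception (e.g. a `SparkException` from
`BarrierCoordinator`).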
Here is a flaky test failure caused by this bug:
```
[info] BarrierTaskContextSuite:
[info] - share messages with allGather() call *** FAILED *** (18 seconds,
705 milliseconds)
[info] org.apache.spark.SparkException: Job aborted due to stage failure:
Could not recover from a failed barrier ResultStage. Most recent failure
reason: Stage failed because barrier task ResultTask(0, 2) finished
unsuccessfully.
[info] java.lang.NullPointerException
[info] at
scala.collection.mutable.ArrayOps$ofRef$.length$extension(ArrayOps.scala:204)
[info] at
scala.collection.mutable.ArrayOps$ofRef.length(ArrayOps.scala:204)
[info] at
scala.collection.IndexedSeqOptimized.toList(IndexedSeqOptimized.scala:285)
[info] at
scala.collection.IndexedSeqOptimized.toList$(IndexedSeqOptimized.scala:284)
[info] at
scala.collection.mutable.ArrayOps$ofRef.toList(ArrayOps.scala:198)
[info] at
org.apache.spark.scheduler.BarrierTaskContextSuite.$anonfun$new$4(BarrierTaskContextSuite.scala:68)
...
```
The test exception can be reproduced by changing the line `messages = ...`
to the following:
```scala
messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 10.micros)
Thread.sleep(5000)
```
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Tested manually and updated some unit tests.
Closes #28312 from Ngone51/cherry-pick-31472.
Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../org/apache/spark/BarrierTaskContext.scala | 30 ++++++++++------------
.../org/apache/spark/rpc/RpcEndpointRef.scala | 10 +++-----
.../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 12 +++++----
.../scala/org/apache/spark/util/ThreadUtils.scala | 5 ++--
.../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 12 ++++-----
5 files changed, 32 insertions(+), 37 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 06f8024..4d76548 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -20,9 +20,9 @@ package org.apache.spark
import java.util.{Properties, Timer, TimerTask}
import scala.collection.JavaConverters._
-import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.{Failure, Success => ScalaSuccess, Try}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
@@ -85,28 +85,26 @@ class BarrierTaskContext private[spark] (
// BarrierCoordinator on timeout, instead of RPCTimeoutException from
the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
- // messages which consist of all barrier tasks' messages
- var messages: Array[String] = null
// Wait the RPC future to be completed, but every 1 second it will jump
out waiting
// and check whether current spark task is killed. If killed, then throw
// a `TaskKilledException`, otherwise continue wait RPC until it
completes.
- try {
- while (!abortableRpcFuture.toFuture.isCompleted) {
+
+ while (!abortableRpcFuture.future.isCompleted) {
+ try {
// wait RPC future for at most 1 second
- try {
- messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture,
1.second)
- } catch {
- case _: TimeoutException | _: InterruptedException =>
- // If `TimeoutException` thrown, waiting RPC future reach 1
second.
- // If `InterruptedException` thrown, it is possible this task is
killed.
- // So in this two cases, we should check whether task is killed
and then
- // throw `TaskKilledException`
- taskContext.killTaskIfInterrupted()
+ Thread.sleep(1000)
+ } catch {
+ case _: InterruptedException => // task is killed by driver
+ } finally {
+ Try(taskContext.killTaskIfInterrupted()) match {
+ case ScalaSuccess(_) => // task is still running healthily
+ case Failure(e) => abortableRpcFuture.abort(e)
}
}
- } finally {
-
abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown
reason."))
}
+ // messages which consist of all barrier tasks' messages. The future
will return the
+ // desired messages if it is completed successfully. Otherwise,
exception could be thrown.
+ val messages = abortableRpcFuture.future.value.get.get
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt
$stageAttemptNumber) finished " +
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
index 56f3d37..a3d27b0 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala
@@ -114,11 +114,7 @@ private[spark] class RpcAbortException(message: String)
extends Exception(messag
* A wrapper for [[Future]] but add abort method.
* This is used in long run RPC and provide an approach to abort the RPC.
*/
-private[spark] class AbortableRpcFuture[T: ClassTag](
- future: Future[T],
- onAbort: String => Unit) {
-
- def abort(reason: String): Unit = onAbort(reason)
-
- def toFuture: Future[T] = future
+private[spark]
+class AbortableRpcFuture[T: ClassTag](val future: Future[T], onAbort:
Throwable => Unit) {
+ def abort(t: Throwable): Unit = onAbort(t)
}
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 265e158..9259ec7 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -208,6 +208,7 @@ private[netty] class NettyRpcEnv(
message: RequestMessage, timeout: RpcTimeout): AbortableRpcFuture[T] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
+ var rpcMsg: Option[RpcOutboxMessage] = None
def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
@@ -226,8 +227,9 @@ private[netty] class NettyRpcEnv(
}
}
- def onAbort(reason: String): Unit = {
- onFailure(new RpcAbortException(reason))
+ def onAbort(t: Throwable): Unit = {
+ onFailure(t)
+ rpcMsg.foreach(_.onAbort())
}
try {
@@ -242,10 +244,10 @@ private[netty] class NettyRpcEnv(
val rpcMessage = RpcOutboxMessage(message.serialize(this),
onFailure,
(client, response) => onSuccess(deserialize[Any](client, response)))
+ rpcMsg = Option(rpcMessage)
postToOutbox(message.receiver, rpcMessage)
promise.future.failed.foreach {
case _: TimeoutException => rpcMessage.onTimeout()
- case _: RpcAbortException => rpcMessage.onAbort()
case _ =>
}(ThreadUtils.sameThread)
}
@@ -270,7 +272,7 @@ private[netty] class NettyRpcEnv(
}
private[netty] def ask[T: ClassTag](message: RequestMessage, timeout:
RpcTimeout): Future[T] = {
- askAbortable(message, timeout).toFuture
+ askAbortable(message, timeout).future
}
private[netty] def serialize(content: Any): ByteBuffer = {
@@ -547,7 +549,7 @@ private[netty] class NettyRpcEndpointRef(
}
override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]
= {
- askAbortable(message, timeout).toFuture
+ askAbortable(message, timeout).future
}
override def send(message: Any): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index e7872bb..78206c5 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -29,7 +29,6 @@ import scala.util.control.NonFatal
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.spark.SparkException
-import org.apache.spark.rpc.RpcAbortException
private[spark] object ThreadUtils {
@@ -299,7 +298,7 @@ private[spark] object ThreadUtils {
// TimeoutException and RpcAbortException is thrown in the current
thread, so not need to warp
// the exception.
case NonFatal(t)
- if !t.isInstanceOf[TimeoutException] &&
!t.isInstanceOf[RpcAbortException] =>
+ if !t.isInstanceOf[TimeoutException] =>
throw new SparkException("Exception thrown in awaitResult: ", t)
}
}
@@ -316,7 +315,7 @@ private[spark] object ThreadUtils {
case e: SparkFatalException =>
throw e.throwable
case NonFatal(t)
- if !t.isInstanceOf[TimeoutException] &&
!t.isInstanceOf[RpcAbortException] =>
+ if !t.isInstanceOf[TimeoutException] =>
throw new SparkException("Exception thrown in awaitResult: ", t)
}
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index c10f2c2..01c67b3 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -209,7 +209,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with
BeforeAndAfterAll {
// Use anotherEnv to find out the RpcEndpointRef
val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-abort")
try {
- val e = intercept[RpcAbortException] {
+ val e = intercept[SparkException] {
val timeout = new RpcTimeout(10.seconds, shortProp)
val abortableRpcFuture = rpcEndpointRef.askAbortable[String](
"hello", timeout)
@@ -217,15 +217,15 @@ abstract class RpcEnvSuite extends SparkFunSuite with
BeforeAndAfterAll {
new Thread {
override def run: Unit = {
Thread.sleep(100)
- abortableRpcFuture.abort("TestAbort")
+ abortableRpcFuture.abort(new RuntimeException("TestAbort"))
}
}.start()
- timeout.awaitResult(abortableRpcFuture.toFuture)
+ timeout.awaitResult(abortableRpcFuture.future)
}
- // The SparkException cause should be a RpcAbortException with
"TestAbort" message
- assert(e.isInstanceOf[RpcAbortException])
- assert(e.getMessage.contains("TestAbort"))
+ // The SparkException cause should be a RuntimeException with
"TestAbort" message
+ assert(e.getCause.isInstanceOf[RuntimeException])
+ assert(e.getCause.getMessage.contains("TestAbort"))
} finally {
anotherEnv.shutdown()
anotherEnv.awaitTermination()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]