This is an automated email from the ASF dual-hosted git repository.
cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c277afb [SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors
c277afb is described below
commit c277afb12b61a91272568dd46380c0d0a9958989
Author: Bryan Cutler <[email protected]>
AuthorDate: Wed Jun 26 13:05:41 2019 -0700
[SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors
## What changes were proposed in this pull request?
Currently, with `toLocalIterator()` and `toPandas()` with Arrow enabled, if the
Spark job run in the background serving thread fails, the error is caught and
sent to Python through the PySpark serializer.
This is not the ideal solution because it only catches a SparkException; it
won't handle an error that occurs in the serializer, and each method has to
have its own special handling to propagate the error.
This PR instead returns the socket auth server object along with the serving
port and authentication secret, so that the Python caller can join with the
serving thread. During the join, the serving thread's Future is completed
either successfully or with an exception; in the latter case, the exception is
propagated to Python through the Py4J call.
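For illustration, a minimal sketch of the calling pattern this change enables on the Python side. It mirrors the updated code in python/pyspark/sql/dataframe.py in the diff below; the wrapper function name and the `jdf` parameter (the Py4J handle, e.g. `df._jdf`) are hypothetical, while `_load_from_socket`, `ArrowCollectSerializer`, `collectAsArrowToPython`, and `getResult` are the existing pieces touched by this diff:
```python
# Sketch only; the wrapper below is not an API added by this PR.
from pyspark.rdd import _load_from_socket
from pyspark.serializers import ArrowCollectSerializer


def collect_arrow_with_error_propagation(jdf):
    # The JVM side now returns a 3-tuple: (port, auth secret, server object)
    port, auth_secret, jsocket_auth_server = jdf.collectAsArrowToPython()
    try:
        # Stream the Arrow batches over the authenticated local socket
        results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
    finally:
        # Join the JVM serving thread; any exception recorded in its Future
        # (from the Spark job or the serializer) is re-raised here via Py4J.
        jsocket_auth_server.getResult()
    return results
```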
## How was this patch tested?
Existing tests
Closes #24834 from BryanCutler/pyspark-propagate-server-error-SPARK-27992.
Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
---
.../org/apache/spark/api/python/PythonRDD.scala | 90 +++++++++++----------
.../main/scala/org/apache/spark/api/r/RRDD.scala | 4 +-
.../apache/spark/security/SocketAuthHelper.scala | 19 +----
.../apache/spark/security/SocketAuthServer.scala | 94 +++++++++++++++-------
.../main/scala/org/apache/spark/util/Utils.scala | 4 +-
python/pyspark/rdd.py | 26 ++++--
python/pyspark/sql/dataframe.py | 10 ++-
python/pyspark/sql/tests/test_arrow.py | 2 +-
.../main/scala/org/apache/spark/sql/Dataset.scala | 38 ++++-----
9 files changed, 161 insertions(+), 126 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fe25c3a..5b80e14 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
import org.apache.spark.util._
@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
- * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
- * data collected from this job, and the secret for authentication.
+ * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, the secret for authentication, and a socket auth
+ * server object that can be used to join the JVM serving thread in Python.
*/
def runJob(
sc: SparkContext,
@@ -156,8 +157,9 @@ private[spark] object PythonRDD extends Logging {
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
- * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
- * data collected from this job, and the secret for authentication.
+ * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, the secret for authentication, and a socket auth
+ * server object that can be used to join the JVM serving thread in Python.
*/
def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
@@ -168,58 +170,59 @@ private[spark] object PythonRDD extends Logging {
* are collected as separate jobs, by order of index. Partition data is first requested by a
* non-zero integer to start a collection job. The response is prefaced by an integer with 1
* meaning partition data will be served, 0 meaning the local iterator has been consumed,
- * and -1 meaining an error occurred during collection. This function is used by
+ * and -1 meaning an error occurred during collection. This function is used by
* pyspark.rdd._local_iterator_from_socket().
*
- * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
- * data collected from these jobs, and the secret for authentication.
+ * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, the secret for authentication, and a socket auth
+ * server object that can be used to join the JVM serving thread in Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
- val (port, secret) = SocketAuthServer.setupOneConnectionServer(
- authHelper, "serve toLocalIterator") { s =>
- val out = new DataOutputStream(s.getOutputStream)
- val in = new DataInputStream(s.getInputStream)
- Utils.tryWithSafeFinally {
-
+ val handleFunc = (sock: Socket) => {
+ val out = new DataOutputStream(sock.getOutputStream)
+ val in = new DataInputStream(sock.getInputStream)
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
}
- // Read request for data and send next partition if nonzero
+ // Write data until iteration is complete, client stops iteration, or error occurs
var complete = false
- while (!complete && in.readInt() != 0) {
- if (collectPartitionIter.hasNext) {
- try {
- // Attempt to collect the next partition
- val partitionArray = collectPartitionIter.next()
-
- // Send response there is a partition to read
- out.writeInt(1)
-
- // Write the next object and signal end of data for this iteration
- writeIteratorToStream(partitionArray.toIterator, out)
- out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
- out.flush()
- } catch {
- case e: SparkException =>
- // Send response that an error occurred followed by error message
- out.writeInt(-1)
- writeUTF(e.getMessage, out)
- complete = true
- }
+ while (!complete) {
+
+ // Read request for data, value of zero will stop iteration or non-zero to continue
+ if (in.readInt() == 0) {
+ complete = true
+ } else if (collectPartitionIter.hasNext) {
+
+ // Client requested more data, attempt to collect the next partition
+ val partitionArray = collectPartitionIter.next()
+
+ // Send response there is a partition to read
+ out.writeInt(1)
+
+ // Write the next object and signal end of data for this iteration
+ writeIteratorToStream(partitionArray.toIterator, out)
+ out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ out.flush()
} else {
// Send response there are no more partitions to read and close
out.writeInt(0)
complete = true
}
}
- } {
+ })(catchBlock = {
+ // Send response that an error occurred, original exception is re-thrown
+ out.writeInt(-1)
+ }, finallyBlock = {
out.close()
in.close()
- }
+ })
}
- Array(port, secret)
+
+ val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
+ Array(server.port, server.secret, server)
}
def readRDDFromFile(
@@ -443,8 +446,9 @@ private[spark] object PythonRDD extends Logging {
*
* The thread will terminate after all the data are sent or any exceptions happen.
*
- * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
- * data collected from this job, and the secret for authentication.
+ * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, the secret for authentication, and a socket auth
+ * server object that can be used to join the JVM serving thread in Python.
*/
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
serveToStream(threadName) { out =>
@@ -464,10 +468,14 @@ private[spark] object PythonRDD extends Logging {
*
* The thread will terminate after the block of code is executed or any
* exceptions happen.
+ *
+ * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, the secret for authentication, and a socket auth
+ * server object that can be used to join the JVM serving thread in Python.
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
- SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
+ SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 07f8405..892e69b 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.SocketAuthServer
private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
parent: RDD[T],
@@ -166,7 +166,7 @@ private[spark] object RRDD {
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
- SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
+ SocketAuthServer.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
}
}
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index 3a107c0..dbcb376 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@
package org.apache.spark.security
-import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
+import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8
@@ -113,21 +113,4 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
dout.write(bytes, 0, bytes.length)
dout.flush()
}
-
-}
-
-private[spark] object SocketAuthHelper {
- def serveToStream(
- threadName: String,
- authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
- val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
- val out = new BufferedOutputStream(s.getOutputStream())
- Utils.tryWithSafeFinally {
- writeFunc(out)
- } {
- out.close()
- }
- }
- Array(port, secret)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index e616d23..548fd1b 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,6 +17,7 @@
package org.apache.spark.security
+import java.io.{BufferedOutputStream, OutputStream}
import java.net.{InetAddress, ServerSocket, Socket}
import scala.concurrent.Promise
@@ -25,12 +26,15 @@ import scala.util.Try
import org.apache.spark.SparkEnv
import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{ThreadUtils, Utils}
/**
* Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
* handling one batch of data, with authentication and error handling.
+ *
+ * The socket server can only accept one connection, or close if no connection
+ * in 15 seconds.
*/
private[spark] abstract class SocketAuthServer[T](
authHelper: SocketAuthHelper,
@@ -41,10 +45,30 @@ private[spark] abstract class SocketAuthServer[T](
private val promise = Promise[T]()
- val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
- promise.complete(Try(handleConnection(sock)))
+ private def startServer(): (Int, String) = {
+ val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+ // Close the socket if no connection in 15 seconds
+ serverSocket.setSoTimeout(15000)
+
+ new Thread(threadName) {
+ setDaemon(true)
+ override def run(): Unit = {
+ var sock: Socket = null
+ try {
+ sock = serverSocket.accept()
+ authHelper.authClient(sock)
+ promise.complete(Try(handleConnection(sock)))
+ } finally {
+ JavaUtils.closeQuietly(serverSocket)
+ JavaUtils.closeQuietly(sock)
+ }
+ }
+ }.start()
+ (serverSocket.getLocalPort, authHelper.secret)
}
+ val (port, secret) = startServer()
+
/**
* Handle a connection which has already been authenticated. Any error from this function
* will clean up this connection and the entire server, and get propagated to [[getResult]].
@@ -66,42 +90,50 @@ private[spark] abstract class SocketAuthServer[T](
}
+/**
+ * Create a socket server class and run user function on the socket in a background thread
+ * that can read and write to the socket input/output streams. The function is passed in a
+ * socket that has been connected and authenticated.
+ */
+private[spark] class SocketFuncServer(
+ authHelper: SocketAuthHelper,
+ threadName: String,
+ func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
+
+ override def handleConnection(sock: Socket): Unit = {
+ func(sock)
+ }
+}
+
private[spark] object SocketAuthServer {
/**
- * Create a socket server and run user function on the socket in a background thread.
+ * Convenience function to create a socket server and run a user function in a background
+ * thread to write to an output stream.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
*
- * The thread will terminate after the supplied user function, or if there are any exceptions.
- *
- * If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
- *
- * @return The port number of a local socket and the secret for authentication.
+ * @param threadName Name for the background serving thread.
+ * @param authHelper SocketAuthHelper for authentication
+ * @param writeFunc User function to write to a given OutputStream
+ * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, the secret for authentication, and a socket auth
+ * server object that can be used to join the JVM serving thread in Python.
*/
- def setupOneConnectionServer(
- authHelper: SocketAuthHelper,
- threadName: String)
- (func: Socket => Unit): (Int, String) = {
- val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
- // Close the socket if no connection in 15 seconds
- serverSocket.setSoTimeout(15000)
-
- new Thread(threadName) {
- setDaemon(true)
- override def run(): Unit = {
- var sock: Socket = null
- try {
- sock = serverSocket.accept()
- authHelper.authClient(sock)
- func(sock)
- } finally {
- JavaUtils.closeQuietly(serverSocket)
- JavaUtils.closeQuietly(sock)
- }
+ def serveToStream(
+ threadName: String,
+ authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
+ val handleFunc = (sock: Socket) => {
+ val out = new BufferedOutputStream(sock.getOutputStream())
+ Utils.tryWithSafeFinally {
+ writeFunc(out)
+ } {
+ out.close()
}
- }.start()
- (serverSocket.getLocalPort, authHelper.secret)
+ }
+
+ val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+ Array(server.port, server.secret, server)
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 00135c3..80d70a1 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
originalThrowable = cause
try {
logError("Aborting task", originalThrowable)
- TaskContext.get().markTaskFailed(originalThrowable)
+ if (TaskContext.get() != null) {
+ TaskContext.get().markTaskFailed(originalThrowable)
+ }
catchBlock
} catch {
case t: Throwable =>
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 395abc8..fa4609d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -140,7 +140,15 @@ def _parse_memory(s):
def _create_local_socket(sock_info):
- (sockfile, sock) = local_connect_and_auth(*sock_info)
+ """
+ Create a local socket that can be used to load deserialized data from the JVM
+
+ :param sock_info: Tuple containing port number and authentication secret for a local socket.
+ :return: sockfile file descriptor of the local socket
+ """
+ port = sock_info[0]
+ auth_secret = sock_info[1]
+ sockfile, sock = local_connect_and_auth(port, auth_secret)
# The RDD materialization time is unpredictable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
@@ -148,6 +156,13 @@ def _create_local_socket(sock_info):
def _load_from_socket(sock_info, serializer):
+ """
+ Connect to a local socket described by sock_info and use the given serializer to yield data
+
+ :param sock_info: Tuple containing port number and authentication secret for a local socket.
+ :param serializer: The PySpark serializer to use
+ :return: result of Serializer.load_stream, usually a generator that yields deserialized data
+ """
sockfile = _create_local_socket(sock_info)
# The socket will be automatically closed when garbage-collected.
return serializer.load_stream(sockfile)
@@ -159,7 +174,8 @@ def _local_iterator_from_socket(sock_info, serializer):
""" Create a synchronous local iterable over a socket """
def __init__(self, _sock_info, _serializer):
- self._sockfile = _create_local_socket(_sock_info)
+ port, auth_secret, self.jsocket_auth_server = _sock_info
+ self._sockfile = _create_local_socket((port, auth_secret))
self._serializer = _serializer
self._read_iter = iter([]) # Initialize as empty iterator
self._read_status = 1
@@ -179,11 +195,9 @@ def _local_iterator_from_socket(sock_info, serializer):
for item in self._read_iter:
yield item
- # An error occurred, read error message and raise it
+ # An error occurred, join serving thread and raise any exceptions from the JVM
elif self._read_status == -1:
- error_msg = UTF8Deserializer().loads(self._sockfile)
- raise RuntimeError("An error occurred while reading the
next element from "
- "toLocalIterator: {}".format(error_msg))
+ self.jsocket_auth_server.getResult()
def __del__(self):
# If local iterator is not fully consumed,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6ba740d..8b0e06d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2200,10 +2200,16 @@ class DataFrame(object):
.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
- sock_info = self._jdf.collectAsArrowToPython()
+ port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
# Collect list of un-ordered batches where last element is a list of correct order indices
- results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+ try:
+ results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
+ finally:
+ # Join serving thread and raise any exceptions from collectAsArrowToPython
+ jsocket_auth_server.getResult()
+
+ # Separate RecordBatches from batch order indices in results
batches = results[:-1]
batch_order = results[-1]
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 839ff888..1f96d2c 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -214,7 +214,7 @@ class ArrowTests(ReusedSQLTestCase):
exception_udf = udf(raise_exception, IntegerType())
df = df.withColumn("error", exception_udf())
with QuietTest(self.sc):
- with self.assertRaisesRegexp(RuntimeError, 'My error'):
+ with self.assertRaisesRegexp(Exception, 'My error'):
df.toPandas()
def _createDataFrame_toggle(self, pdf, schema=None):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a80aade..45ec7dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.StringUtils
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.TaskContext
import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
@@ -3321,34 +3321,24 @@ class Dataset[T] private[sql](
}
}
- var sparkException: Option[SparkException] = None
- try {
+ Utils.tryWithSafeFinally {
val arrowBatchRdd = toArrowBatchRdd(plan)
sparkSession.sparkContext.runJob(
arrowBatchRdd,
(it: Iterator[Array[Byte]]) => it.toArray,
handlePartitionBatches)
- } catch {
- case e: SparkException =>
- sparkException = Some(e)
- }
-
- // After processing all partitions, end the batch stream
- batchWriter.end()
- sparkException match {
- case Some(exception) =>
- // Signal failure and write error message
- out.writeInt(-1)
- PythonRDD.writeUTF(exception.getMessage, out)
- case None =>
- // Write batch order indices
- out.writeInt(batchOrder.length)
- // Sort by (index of partition, batch index in that partition) tuple to get the
- // overall_batch_index from 0 to N-1 batches, which can be used to put the
- // transferred batches in the correct order
- batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
- out.writeInt(overallBatchIndex)
- }
+ } {
+ // After processing all partitions, end the batch stream
+ batchWriter.end()
+
+ // Write batch order indices
+ out.writeInt(batchOrder.length)
+ // Sort by (index of partition, batch index in that partition) tuple to get the
+ // overall_batch_index from 0 to N-1 batches, which can be used to put the
+ // transferred batches in the correct order
+ batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+ out.writeInt(overallBatchIndex)
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]