This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c277afb  [SPARK-27992][PYTHON] Allow Python to join with connection 
thread to propagate errors
c277afb is described below

commit c277afb12b61a91272568dd46380c0d0a9958989
Author: Bryan Cutler <[email protected]>
AuthorDate: Wed Jun 26 13:05:41 2019 -0700

    [SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors
    
    ## What changes were proposed in this pull request?
    
    Currently, with `toLocalIterator()` and with `toPandas()` when Arrow is enabled, if 
the Spark job running in the background serving thread fails, the error is 
caught and sent to Python through the PySpark serializer.
    This is not an ideal solution because it only catches a SparkException: 
it won't handle an error that occurs in the serializer, and each method has to 
have its own special handling to propagate the error.
    
    This PR instead returns the JVM socket auth server object along with the serving 
port and authentication secret, so that the Python caller can join with 
the serving thread. During the call to join, the serving thread's Future is 
completed either successfully or with an exception. In the latter case, the 
exception is propagated to Python through the Py4J call.
    
    ## How was this patch tested?
    
    Existing tests
    
    Closes #24834 from BryanCutler/pyspark-propagate-server-error-SPARK-27992.
    
    Authored-by: Bryan Cutler <[email protected]>
    Signed-off-by: Bryan Cutler <[email protected]>
---
 .../org/apache/spark/api/python/PythonRDD.scala    | 90 +++++++++++----------
 .../main/scala/org/apache/spark/api/r/RRDD.scala   |  4 +-
 .../apache/spark/security/SocketAuthHelper.scala   | 19 +----
 .../apache/spark/security/SocketAuthServer.scala   | 94 +++++++++++++++-------
 .../main/scala/org/apache/spark/util/Utils.scala   |  4 +-
 python/pyspark/rdd.py                              | 26 ++++--
 python/pyspark/sql/dataframe.py                    | 10 ++-
 python/pyspark/sql/tests/test_arrow.py             |  2 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 38 ++++-----
 9 files changed, 161 insertions(+), 126 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fe25c3a..5b80e14 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, 
SocketFuncServer}
 import org.apache.spark.util._
 
 
@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
    * (effectively a collect()), but allows you to run on a certain subset of 
partitions,
    * or to enable local execution.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket 
which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket 
which serves the
+   *         data collected from this job, the secret for authentication, and 
a socket auth
+   *         server object that can be used to join the JVM serving thread in 
Python.
    */
   def runJob(
       sc: SparkContext,
@@ -156,8 +157,9 @@ private[spark] object PythonRDD extends Logging {
   /**
    * A helper function to collect an RDD as an iterator, then serve it via 
socket.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket 
which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket 
which serves the
+   *         data collected from this job, the secret for authentication, and 
a socket auth
+   *         server object that can be used to join the JVM serving thread in 
Python.
    */
   def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
@@ -168,58 +170,59 @@ private[spark] object PythonRDD extends Logging {
    * are collected as separate jobs, by order of index. Partition data is 
first requested by a
    * non-zero integer to start a collection job. The response is prefaced by 
an integer with 1
    * meaning partition data will be served, 0 meaning the local iterator has 
been consumed,
-   * and -1 meaining an error occurred during collection. This function is 
used by
+   * and -1 meaning an error occurred during collection. This function is used 
by
    * pyspark.rdd._local_iterator_from_socket().
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket 
which serves the
-   *         data collected from these jobs, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket 
which serves the
+   *         data collected from this job, the secret for authentication, and 
a socket auth
+   *         server object that can be used to join the JVM serving thread in 
Python.
    */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
-        authHelper, "serve toLocalIterator") { s =>
-      val out = new DataOutputStream(s.getOutputStream)
-      val in = new DataInputStream(s.getInputStream)
-      Utils.tryWithSafeFinally {
-
+    val handleFunc = (sock: Socket) => {
+      val out = new DataOutputStream(sock.getOutputStream)
+      val in = new DataInputStream(sock.getInputStream)
+      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
           rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, 
Seq(i)).head
         }
 
-        // Read request for data and send next partition if nonzero
+        // Write data until iteration is complete, client stops iteration, or 
error occurs
         var complete = false
-        while (!complete && in.readInt() != 0) {
-          if (collectPartitionIter.hasNext) {
-            try {
-              // Attempt to collect the next partition
-              val partitionArray = collectPartitionIter.next()
-
-              // Send response there is a partition to read
-              out.writeInt(1)
-
-              // Write the next object and signal end of data for this 
iteration
-              writeIteratorToStream(partitionArray.toIterator, out)
-              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-              out.flush()
-            } catch {
-              case e: SparkException =>
-                // Send response that an error occurred followed by error 
message
-                out.writeInt(-1)
-                writeUTF(e.getMessage, out)
-                complete = true
-            }
+        while (!complete) {
+
+          // Read request for data, value of zero will stop iteration or 
non-zero to continue
+          if (in.readInt() == 0) {
+            complete = true
+          } else if (collectPartitionIter.hasNext) {
+
+            // Client requested more data, attempt to collect the next 
partition
+            val partitionArray = collectPartitionIter.next()
+
+            // Send response there is a partition to read
+            out.writeInt(1)
+
+            // Write the next object and signal end of data for this iteration
+            writeIteratorToStream(partitionArray.toIterator, out)
+            out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+            out.flush()
           } else {
             // Send response there are no more partitions to read and close
             out.writeInt(0)
             complete = true
           }
         }
-      } {
+      })(catchBlock = {
+        // Send response that an error occurred, original exception is 
re-thrown
+        out.writeInt(-1)
+      }, finallyBlock = {
         out.close()
         in.close()
-      }
+      })
     }
-    Array(port, secret)
+
+    val server = new SocketFuncServer(authHelper, "serve toLocalIterator", 
handleFunc)
+    Array(server.port, server.secret, server)
   }
 
   def readRDDFromFile(
@@ -443,8 +446,9 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions 
happen.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket 
which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket 
which serves the
+   *         data collected from this job, the secret for authentication, and 
a socket auth
+   *         server object that can be used to join the JVM serving thread in 
Python.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     serveToStream(threadName) { out =>
@@ -464,10 +468,14 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after the block of code is executed or any
    * exceptions happen.
+   *
+   * @return 3-tuple (as a Java array) with the port number of a local socket 
which serves the
+   *         data collected from this job, the secret for authentication, and 
a socket auth
+   *         server object that can be used to join the JVM serving thread in 
Python.
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
+    SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala 
b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 07f8405..892e69b 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, 
JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.SocketAuthServer
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -166,7 +166,7 @@ private[spark] object RRDD {
 
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, new 
RAuthHelper(SparkEnv.get.conf))(writeFunc)
+    SocketAuthServer.serveToStream(threadName, new 
RAuthHelper(SparkEnv.get.conf))(writeFunc)
   }
 }
 
diff --git 
a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala 
b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index 3a107c0..dbcb376 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.security
 
-import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, 
OutputStream}
+import java.io.{DataInputStream, DataOutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
 
@@ -113,21 +113,4 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
     dout.write(bytes, 0, bytes.length)
     dout.flush()
   }
-
-}
-
-private[spark] object SocketAuthHelper {
-  def serveToStream(
-      threadName: String,
-      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): 
Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, 
threadName) { s =>
-      val out = new BufferedOutputStream(s.getOutputStream())
-      Utils.tryWithSafeFinally {
-        writeFunc(out)
-      } {
-        out.close()
-      }
-    }
-    Array(port, secret)
-  }
 }
diff --git 
a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala 
b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index e616d23..548fd1b 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.security
 
+import java.io.{BufferedOutputStream, OutputStream}
 import java.net.{InetAddress, ServerSocket, Socket}
 
 import scala.concurrent.Promise
@@ -25,12 +26,15 @@ import scala.util.Try
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 
 /**
  * Creates a server in the JVM to communicate with external processes (e.g., 
Python and R) for
  * handling one batch of data, with authentication and error handling.
+ *
+ * The socket server can only accept one connection, or close if no connection
+ * in 15 seconds.
  */
 private[spark] abstract class SocketAuthServer[T](
     authHelper: SocketAuthHelper,
@@ -41,10 +45,30 @@ private[spark] abstract class SocketAuthServer[T](
 
   private val promise = Promise[T]()
 
-  val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, 
threadName) { sock =>
-    promise.complete(Try(handleConnection(sock)))
+  private def startServer(): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, 
InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          promise.complete(Try(handleConnection(sock)))
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
   }
 
+  val (port, secret) = startServer()
+
   /**
    * Handle a connection which has already been authenticated.  Any error from 
this function
    * will clean up this connection and the entire server, and get propagated 
to [[getResult]].
@@ -66,42 +90,50 @@ private[spark] abstract class SocketAuthServer[T](
 
 }
 
+/**
+ * Create a socket server class and run user function on the socket in a 
background thread
+ * that can read and write to the socket input/output streams. The function is 
passed in a
+ * socket that has been connected and authenticated.
+ */
+private[spark] class SocketFuncServer(
+    authHelper: SocketAuthHelper,
+    threadName: String,
+    func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, 
threadName) {
+
+  override def handleConnection(sock: Socket): Unit = {
+    func(sock)
+  }
+}
+
 private[spark] object SocketAuthServer {
 
   /**
-   * Create a socket server and run user function on the socket in a 
background thread.
+   * Convenience function to create a socket server and run a user function in 
a background
+   * thread to write to an output stream.
    *
    * The socket server can only accept one connection, or close if no 
connection
    * in 15 seconds.
    *
-   * The thread will terminate after the supplied user function, or if there 
are any exceptions.
-   *
-   * If you need to get a result of the supplied function, create a subclass 
of [[SocketAuthServer]]
-   *
-   * @return The port number of a local socket and the secret for 
authentication.
+   * @param threadName Name for the background serving thread.
+   * @param authHelper SocketAuthHelper for authentication
+   * @param writeFunc User function to write to a given OutputStream
+   * @return 3-tuple (as a Java array) with the port number of a local socket 
which serves the
+   *         data collected from this job, the secret for authentication, and 
a socket auth
+   *         server object that can be used to join the JVM serving thread in 
Python.
    */
-  def setupOneConnectionServer(
-      authHelper: SocketAuthHelper,
-      threadName: String)
-      (func: Socket => Unit): (Int, String) = {
-    val serverSocket = new ServerSocket(0, 1, 
InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run(): Unit = {
-        var sock: Socket = null
-        try {
-          sock = serverSocket.accept()
-          authHelper.authClient(sock)
-          func(sock)
-        } finally {
-          JavaUtils.closeQuietly(serverSocket)
-          JavaUtils.closeQuietly(sock)
-        }
+  def serveToStream(
+      threadName: String,
+      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): 
Array[Any] = {
+    val handleFunc = (sock: Socket) => {
+      val out = new BufferedOutputStream(sock.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
       }
-    }.start()
-    (serverSocket.getLocalPort, authHelper.secret)
+    }
+
+    val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+    Array(server.port, server.secret, server)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 00135c3..80d70a1 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
         originalThrowable = cause
         try {
           logError("Aborting task", originalThrowable)
-          TaskContext.get().markTaskFailed(originalThrowable)
+          if (TaskContext.get() != null) {
+            TaskContext.get().markTaskFailed(originalThrowable)
+          }
           catchBlock
         } catch {
           case t: Throwable =>
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 395abc8..fa4609d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -140,7 +140,15 @@ def _parse_memory(s):
 
 
 def _create_local_socket(sock_info):
-    (sockfile, sock) = local_connect_and_auth(*sock_info)
+    """
+    Create a local socket that can be used to load deserialized data from the 
JVM
+
+    :param sock_info: Tuple containing port number and authentication secret 
for a local socket.
+    :return: sockfile file descriptor of the local socket
+    """
+    port = sock_info[0]
+    auth_secret = sock_info[1]
+    sockfile, sock = local_connect_and_auth(port, auth_secret)
     # The RDD materialization time is unpredictable, if we set a timeout for 
socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
@@ -148,6 +156,13 @@ def _create_local_socket(sock_info):
 
 
 def _load_from_socket(sock_info, serializer):
+    """
+    Connect to a local socket described by sock_info and use the given 
serializer to yield data
+
+    :param sock_info: Tuple containing port number and authentication secret 
for a local socket.
+    :param serializer: The PySpark serializer to use
+    :return: result of Serializer.load_stream, usually a generator that yields 
deserialized data
+    """
     sockfile = _create_local_socket(sock_info)
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
@@ -159,7 +174,8 @@ def _local_iterator_from_socket(sock_info, serializer):
         """ Create a synchronous local iterable over a socket """
 
         def __init__(self, _sock_info, _serializer):
-            self._sockfile = _create_local_socket(_sock_info)
+            port, auth_secret, self.jsocket_auth_server = _sock_info
+            self._sockfile = _create_local_socket((port, auth_secret))
             self._serializer = _serializer
             self._read_iter = iter([])  # Initialize as empty iterator
             self._read_status = 1
@@ -179,11 +195,9 @@ def _local_iterator_from_socket(sock_info, serializer):
                     for item in self._read_iter:
                         yield item
 
-                # An error occurred, read error message and raise it
+                # An error occurred, join serving thread and raise any 
exceptions from the JVM
                 elif self._read_status == -1:
-                    error_msg = UTF8Deserializer().loads(self._sockfile)
-                    raise RuntimeError("An error occurred while reading the 
next element from "
-                                       "toLocalIterator: {}".format(error_msg))
+                    self.jsocket_auth_server.getResult()
 
         def __del__(self):
             # If local iterator is not fully consumed,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6ba740d..8b0e06d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2200,10 +2200,16 @@ class DataFrame(object):
         .. note:: Experimental.
         """
         with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.collectAsArrowToPython()
+            port, auth_secret, jsocket_auth_server = 
self._jdf.collectAsArrowToPython()
 
         # Collect list of un-ordered batches where last element is a list of 
correct order indices
-        results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+        try:
+            results = list(_load_from_socket((port, auth_secret), 
ArrowCollectSerializer()))
+        finally:
+            # Join serving thread and raise any exceptions from 
collectAsArrowToPython
+            jsocket_auth_server.getResult()
+
+        # Separate RecordBatches from batch order indices in results
         batches = results[:-1]
         batch_order = results[-1]
 
diff --git a/python/pyspark/sql/tests/test_arrow.py 
b/python/pyspark/sql/tests/test_arrow.py
index 839ff888..1f96d2c 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -214,7 +214,7 @@ class ArrowTests(ReusedSQLTestCase):
         exception_udf = udf(raise_exception, IntegerType())
         df = df.withColumn("error", exception_udf())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(RuntimeError, 'My error'):
+            with self.assertRaisesRegexp(Exception, 'My error'):
                 df.toPandas()
 
     def _createDataFrame_toggle(self, pdf, schema=None):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a80aade..45ec7dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, 
Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3321,34 +3321,24 @@ class Dataset[T] private[sql](
             }
           }
 
-        var sparkException: Option[SparkException] = None
-        try {
+        Utils.tryWithSafeFinally {
           val arrowBatchRdd = toArrowBatchRdd(plan)
           sparkSession.sparkContext.runJob(
             arrowBatchRdd,
             (it: Iterator[Array[Byte]]) => it.toArray,
             handlePartitionBatches)
-        } catch {
-          case e: SparkException =>
-            sparkException = Some(e)
-        }
-
-        // After processing all partitions, end the batch stream
-        batchWriter.end()
-        sparkException match {
-          case Some(exception) =>
-            // Signal failure and write error message
-            out.writeInt(-1)
-            PythonRDD.writeUTF(exception.getMessage, out)
-          case None =>
-            // Write batch order indices
-            out.writeInt(batchOrder.length)
-            // Sort by (index of partition, batch index in that partition) 
tuple to get the
-            // overall_batch_index from 0 to N-1 batches, which can be used to 
put the
-            // transferred batches in the correct order
-            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, 
overallBatchIndex) =>
-              out.writeInt(overallBatchIndex)
-            }
+        } {
+          // After processing all partitions, end the batch stream
+          batchWriter.end()
+
+          // Write batch order indices
+          out.writeInt(batchOrder.length)
+          // Sort by (index of partition, batch index in that partition) tuple 
to get the
+          // overall_batch_index from 0 to N-1 batches, which can be used to 
put the
+          // transferred batches in the correct order
+          batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, 
overallBatchIndex) =>
+            out.writeInt(overallBatchIndex)
+          }
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to