This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 5bfaa71d7bc [SPARK-42944][SS][PYTHON] Streaming ForeachBatch in Python 5bfaa71d7bc is described below commit 5bfaa71d7bc63a19c73bc0208eccf1d68dbf6ac7 Author: Raghu Angadi <raghu.ang...@databricks.com> AuthorDate: Wed Jul 19 09:04:05 2023 +0900 [SPARK-42944][SS][PYTHON] Streaming ForeachBatch in Python Adds `foreachBatch()` in Python. This adds a new runner `StreamingPythonRunner`. Note that this PR focuses on core functionality and includes TODO for followup improvements (will update with jira tickets where missing). Included more inline comments to help with the review. ### What changes were proposed in this pull request? Adds support for foreachBatch() in Spark connect. ### Why are the changes needed? - Manual tests - Unit tests: - The tests are updated to use a global temp view, rather than shared variable since connect version of the function runs on the server side. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Closes #42035 from rangadi/feb-py. Authored-by: Raghu Angadi <raghu.ang...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit d93f6d145142ec15a96b0a3bfbbaa044c4b725e9) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../spark/sql/streaming/StreamingQuerySuite.scala | 2 +- .../sql/connect/planner/SparkConnectPlanner.scala | 5 +- .../planner/StreamingForeachBatchHelper.scala | 70 +++++++++++++++-- .../org/apache/spark/api/python/PythonRunner.scala | 1 + .../spark/api/python/PythonWorkerFactory.scala | 9 ++- .../spark/api/python/StreamingPythonRunner.scala | 88 ++++++++++++++++++++++ python/pyspark/sql/connect/session.py | 9 +++ python/pyspark/sql/connect/streaming/readwriter.py | 12 +-- .../test_parity_streaming_foreachBatch.py | 44 +++++++++++ .../tests/streaming/test_streaming_foreachBatch.py | 20 +++-- python/pyspark/streaming_worker.py | 78 +++++++++++++++++++ 11 files changed, 311 insertions(+), 27 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1287176d76e..91d744b9e48 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -335,7 +335,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging { .start() eventually(timeout(30.seconds)) { // Wait for first progress. 
- assert(q.lastProgress != null) + assert(q.lastProgress != null, "Failed to make progress") assert(q.lastProgress.numInputRows > 0) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 39cb4c1b972..92a9524f67a 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2792,7 +2792,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { if (writeOp.hasForeachBatch) { val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match { case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION => - throw InvalidPlanInput("Python ForeachBatch is not supported yet. WIP.") + val pythonFn = transformPythonFunction(writeOp.getForeachBatch.getPythonFunction) + StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder) case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION => val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType]( @@ -2801,7 +2802,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder) case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET => - throw InvalidPlanInput("Unexpected") + throw InvalidPlanInput("Unexpected") // Unreachable } writer.foreachBatch(foreachBatchFn) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 66487e7048c..31481393777 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -18,9 +18,13 @@ package org.apache.spark.sql.connect.planner import java.util.UUID +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.SimplePythonFunction +import org.apache.spark.api.python.StreamingPythonRunner import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.connect.service.SparkConnectService /** * A helper class for handling ForeachBatch related functionality in Spark Connect servers @@ -29,23 +33,25 @@ object StreamingForeachBatchHelper extends Logging { type ForeachBatchFnType = (DataFrame, Long) => Unit + private case class FnArgsWithId(dfId: String, df: DataFrame, batchId: Long) + /** * Return a new ForeachBatch function that wraps `fn`. It sets up DataFrame cache so that the * user function can access it. The cache is cleared once ForeachBatch returns. */ - def dataFrameCachingWrapper( - fn: ForeachBatchFnType, + private def dataFrameCachingWrapper( + fn: FnArgsWithId => Unit, sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) => { val dfId = UUID.randomUUID().toString log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to the log. - // TODO: Sanity check there is no other active DataFrame for this query. Need to include - // query id available in the cache for this check. 
+ // TODO: Sanity check there is no other active DataFrame for this query. The query id + // needs to be saved in the cache for this check. sessionHolder.cacheDataFrameById(dfId, df) try { - fn(df, batchId) + fn(FnArgsWithId(dfId, df, batchId)) } finally { log.info(s"Removing DataFrame with id $dfId from the cache") sessionHolder.removeCachedDataFrame(dfId) @@ -57,13 +63,61 @@ object StreamingForeachBatchHelper extends Logging { * Handles setting up Scala remote session and other Spark Connect environment and then runs the * provided foreachBatch function `fn`. * - * HACK ALERT: This version does not atually set up Spark connect. Directly passes the - * DataFrame, so the user code actually runs with legacy DataFrame. + * HACK ALERT: This version does not actually set up a Spark Connect session. Directly passes the + * DataFrame, so the user code actually runs with legacy DataFrame and session. */ def scalaForeachBatchWrapper( fn: ForeachBatchFnType, sessionHolder: SessionHolder): ForeachBatchFnType = { // TODO: Set up Spark Connect session. Do we actually need this for the first version? - dataFrameCachingWrapper(fn, sessionHolder) + dataFrameCachingWrapper( + (args: FnArgsWithId) => { + fn(args.df, args.batchId) // dfId is not used, see hack comment above. + }, + sessionHolder) } + + /** + * Starts up Python worker and initializes it with Python function. Returns a foreachBatch + * function that sets up the session and DataFrame cache and interacts with the Python + * worker to execute user's function. + */ + def pythonForeachBatchWrapper( + pythonFn: SimplePythonFunction, + sessionHolder: SessionHolder): ForeachBatchFnType = { + + val port = SparkConnectService.localPort + val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + val runner = StreamingPythonRunner(pythonFn, connectUrl) + val (dataOut, dataIn) = runner.init(sessionHolder.sessionId) + + val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { + + // TODO(SPARK-44460): Support Auth credentials + // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created. + // This is because MicroBatch execution clones the session during start. + // The session attached to the foreachBatch dataframe is different from the one + // the query was started with. `sessionHolder` here contains the latter. + + PythonRDD.writeUTF(args.dfId, dataOut) + dataOut.writeLong(args.batchId) + dataOut.flush() + + val ret = dataIn.readInt() + log.info(s"Python foreach batch for dfId ${args.dfId} completed (ret: $ret)") + } + + dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder) + } + + // TODO(SPARK-44433): Improve termination of Processes + // The goal is that when a query is terminated, the Python process associated with foreachBatch + // should be terminated. One way to do that is by registering a streaming query listener: + // After pythonForeachBatchWrapper() is invoked by the SparkConnectPlanner. + // At that time, we don't have the streaming queries yet. + // Planner should call back into this helper with the query id when it starts it immediately + // after. Save the query id to StreamingPythonRunner mapping. This mapping should be + // part of the SessionHolder. + // When a query is terminated, check the mapping and terminate any associated runner. + // These runners should be terminated when a session is deleted (due to timeout, etc). 
} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index ffb10985768..2831ae74f56 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -799,6 +799,7 @@ private[spark] object SpecialLengths { val END_OF_STREAM = -4 val NULL = -5 val START_ARROW_STREAM = -6 + val END_OF_MICRO_BATCH = -7 } private[spark] object BarrierTaskContextMessageProtocol { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 19181bd98e1..6039f8d232b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -106,10 +106,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } createThroughDaemon() } else { - createSimpleWorker() + createSimpleWorker(workerModule) } } + /** Creates a Python worker with `pyspark.streaming_worker` module. */ + def createStreamingWorker(): (Socket, Option[Int]) = { + createSimpleWorker("pyspark.streaming_worker") + } + /** * Connect to a worker launched through pyspark/daemon.py (by default), which forks python * processes itself to avoid the high cost of forking from Java. This currently only works @@ -150,7 +155,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String /** * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ - private def createSimpleWorker(): (Socket, Option[Int]) = { + private def createSimpleWorker(workerModule: String): (Socket, Option[Int]) = { var serverSocket: ServerSocket = null try { serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress()) diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala new file mode 100644 index 00000000000..77dc88e0cfa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} +import java.net.Socket + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.BUFFER_SIZE +import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTHON_USE_DAEMON} + + +private[spark] object StreamingPythonRunner { + def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = { + new StreamingPythonRunner(func, connectUrl) + } +} + +private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String) + extends Logging { + private val conf = SparkEnv.get.conf + protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) + + private val envVars: java.util.Map[String, String] = func.envVars + private val pythonExec: String = func.pythonExec + protected val pythonVer: String = func.pythonVer + + /** + * Initializes the Python worker for streaming functions. Sets up Spark Connect session + * to be used with the functions. + */ + def init(sessionId: String): (DataOutputStream, DataInputStream) = { + log.info(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") + + val env = SparkEnv.get + + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) + + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + conf.set(PYTHON_USE_DAEMON, false) + envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) + + val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap) + val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker() + + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + + // TODO: verify python version + + // Send sessionId + PythonRDD.writeUTF(sessionId, dataOut) + + // send the user function to python process + val command = func.command + dataOut.writeInt(command.length) + dataOut.write(command.toArray) + dataOut.flush() + + val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val resFromPython = dataIn.readInt() + log.info(s"Runner initialization returned $resFromPython") + + (dataOut, dataIn) + } +} diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 52eab1bf5f9..37a5bdd9f9f 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -58,6 +58,7 @@ from pyspark.sql.connect.plan import ( LogicalPlan, CachedLocalRelation, CachedRelation, + CachedRemoteRelation, ) from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager @@ -670,6 +671,14 @@ class SparkSession: copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__ + def _createRemoteDataFrame(self, remote_id: str) -> "DataFrame": + """ + In internal API to reference a runtime DataFrame on the server side. + This is used in ForeachBatch() runner, where the remote DataFrame refers to the + output of a micro batch. 
+ """ + return DataFrame.withPlan(CachedRemoteRelation(remote_id), self) + @staticmethod def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: """ diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 156a3ba87db..c8cd408404f 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -32,7 +32,7 @@ from pyspark.sql.streaming.readwriter import ( DataStreamWriter as PySparkDataStreamWriter, ) from pyspark.sql.types import Row, StructType -from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkNotImplementedError +from pyspark.errors import PySparkTypeError, PySparkValueError if TYPE_CHECKING: from pyspark.sql.connect.session import SparkSession @@ -495,14 +495,14 @@ class DataStreamWriter: foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__ - # TODO (SPARK-42944): Implement and uncomment the doc def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter": - raise PySparkNotImplementedError( - error_class="NOT_IMPLEMENTED", - message_parameters={"feature": "foreachBatch()"}, + self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps( + func ) + self._write_proto.foreach_batch.python_function.python_ver = "%d.%d" % sys.version_info[:2] + return self - # foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ + foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__ def _start_internal( self, diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py new file mode 100644 index 00000000000..c4aa936a43e --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase): + @unittest.skip("SPARK-44463: Error handling needs improvement in connect foreachBatch") + def test_streaming_foreachBatch_propagates_python_errors(self): + super().test_streaming_foreachBatch_propagates_python_errors + + @unittest.skip("This seems specific to py4j and pinned threads. 
The intention is unclear") + def test_streaming_foreachBatch_graceful_stop(self): + super().test_streaming_foreachBatch_graceful_stop() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.streaming.test_parity_streaming_foreachBatch import * # noqa: F401,E501 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py index 7e5720e4299..d4e185c3d85 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py @@ -20,40 +20,40 @@ import time from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingTestsForeachBatch(ReusedSQLTestCase): +class StreamingTestsForeachBatchMixin: def test_streaming_foreachBatch(self): q = None - collected = dict() def collectBatch(batch_df, batch_id): - collected[batch_id] = batch_df.collect() + batch_df.createOrReplaceGlobalTempView("test_view") try: df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") q = df.writeStream.foreachBatch(collectBatch).start() q.processAllAvailable() - self.assertTrue(0 in collected) - self.assertTrue(len(collected[0]), 2) + collected = self.spark.sql("select * from global_temp.test_view").collect() + self.assertTrue(len(collected), 2) finally: if q: q.stop() def test_streaming_foreachBatch_tempview(self): q = None - collected = dict() def collectBatch(batch_df, batch_id): batch_df.createOrReplaceTempView("updates") # it should use the spark session within given DataFrame, as microbatch execution will # clone the session which is no longer same with the session used to start the # streaming query - collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect() + assert len(batch_df.sparkSession.sql("SELECT * FROM updates").collect()) == 2 + # Write to a global view verify on the repl/client side. + batch_df.createOrReplaceGlobalTempView("temp_view") try: df = self.spark.readStream.format("text").load("python/test_support/sql/streaming") q = df.writeStream.foreachBatch(collectBatch).start() q.processAllAvailable() - self.assertTrue(0 in collected) + collected = self.spark.sql("SELECT * FROM global_temp.temp_view").collect() self.assertTrue(len(collected[0]), 2) finally: if q: @@ -89,6 +89,10 @@ class StreamingTestsForeachBatch(ReusedSQLTestCase): self.assertIsNone(q.exception(), "No exception has to be propagated.") +class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": import unittest from pyspark.sql.tests.streaming.test_streaming_foreachBatch import * # noqa: F401 diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/streaming_worker.py new file mode 100644 index 00000000000..490bae44d99 --- /dev/null +++ b/python/pyspark/streaming_worker.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A worker for streaming foreachBatch and query listener in Spark Connect. +""" +import os + +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + write_int, + read_long, + UTF8Deserializer, + CPickleSerializer, +) +from pyspark import worker +from pyspark.sql import SparkSession + +pickleSer = CPickleSerializer() +utf8_deserializer = UTF8Deserializer() + + +def main(infile, outfile): # type: ignore[no-untyped-def] + log_name = "Streaming ForeachBatch worker" + connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] + sessionId = utf8_deserializer.loads(infile) + + print(f"{log_name} is starting with url {connect_url} and sessionId {sessionId}.") + + sparkConnectSession = SparkSession.builder.remote(connect_url).getOrCreate() + sparkConnectSession._client._session_id = sessionId + + # TODO(SPARK-44460): Pass credentials. + # TODO(SPARK-44461): Enable Process Isolation + + func = worker.read_command(pickleSer, infile) + write_int(0, outfile) # Indicate successful initialization + + outfile.flush() + + def process(dfId, batchId): # type: ignore[no-untyped-def] + print(f"{log_name} Started batch {batchId} with DF id {dfId}") + batchDf = sparkConnectSession._createRemoteDataFrame(dfId) + func(batchDf, batchId) + print(f"{log_name} Completed batch {batchId} with DF id {dfId}") + + while True: + dfRefId = utf8_deserializer.loads(infile) + batchId = read_long(infile) + process(dfRefId, int(batchId)) # TODO(SPARK-44463): Propagate error to the user. + write_int(0, outfile) + outfile.flush() + + +if __name__ == "__main__": + print("Starting streaming worker") + + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() + main(sock_file, sock_file) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
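For reference, a minimal end-to-end usage sketch of the foreachBatch() path this commit enables over Spark Connect. The connect URL, source, and view name are illustrative, and it assumes a Spark Connect server is already running. Since the user function now executes in a server-side Python worker, results are surfaced through a global temp view, mirroring the updated tests.

import time
from pyspark.sql import SparkSession

# Illustrative URL; any reachable Spark Connect endpoint works.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

def process_batch(batch_df, batch_id):
    # Runs inside the server-side Python worker started by StreamingPythonRunner,
    # not in the client process, so write results somewhere globally visible.
    batch_df.createOrReplaceGlobalTempView("latest_batch")

query = (
    spark.readStream.format("rate").load()
    .writeStream.foreachBatch(process_batch)
    .start()
)
time.sleep(10)  # let a few micro-batches run
spark.sql("SELECT * FROM global_temp.latest_batch").show()
query.stop()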
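On the client, DataStreamWriter.foreachBatch() now simply cloudpickles the user function into the write proto (python_function.command) together with the Python version. The sketch below illustrates why the tests had to move from a shared dict to a global temp view: closures captured in the pickle are copied into the server-side worker, so mutating a client-side variable no longer has any visible effect. Names here are illustrative.

from pyspark.serializers import CloudPickleSerializer

collected = {}

def collect_batch(batch_df, batch_id):
    collected[batch_id] = batch_df.collect()  # mutates a copy captured in the pickle

# Essentially what the connect DataStreamWriter.foreachBatch() sends to the server.
payload = CloudPickleSerializer().dumps(collect_batch)

# The server unpickles 'payload' in its own Python worker, which gets its own
# 'collected' dict; the client's dict stays empty. Writing to a global temp view
# (as the updated tests do) makes the results observable from the client session.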
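The server-side dataFrameCachingWrapper registers the micro-batch DataFrame under a fresh UUID, runs the wrapped function, and always removes the entry afterwards. A Python paraphrase of that register/try/finally lifecycle, only to make the pattern explicit; the real cache lives on the SessionHolder and the real code is the Scala shown above.

import uuid

def data_frame_caching_wrapper(fn, cache):
    """Expose the batch DataFrame under a fresh id for the duration of one call."""
    def wrapped(df, batch_id):
        df_id = str(uuid.uuid4())
        cache[df_id] = df               # analogous to sessionHolder.cacheDataFrameById
        try:
            fn(df_id, df, batch_id)     # the Python path forwards df_id and batch_id
        finally:
            cache.pop(df_id, None)      # analogous to sessionHolder.removeCachedDataFrame
    return wrapped

# Toy usage: the cache entry exists only while the wrapped function runs.
cache = {}
wrapped = data_frame_caching_wrapper(
    lambda df_id, df, batch_id: print(df_id, batch_id, df_id in cache), cache)
wrapped("fake-dataframe", 0)
assert not cache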
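Per micro-batch, pythonForeachBatchWrapper writes the cached DataFrame id (length-prefixed UTF-8 via PythonRDD.writeUTF) and the batch id (an 8-byte long) to the worker, and the worker answers with a 4-byte status int after running the user function. A standalone sketch of that framing, assuming the big-endian layout used by Java's DataOutputStream and pyspark.serializers; it does not use any Spark classes.

import struct

def encode_batch_request(df_id: str, batch_id: int) -> bytes:
    # JVM side: writeUTF writes a 4-byte length then UTF-8 bytes, writeLong an 8-byte long.
    utf = df_id.encode("utf-8")
    return struct.pack(">i", len(utf)) + utf + struct.pack(">q", batch_id)

def decode_batch_request(buf: bytes):
    # Worker side: UTF8Deserializer.loads() plus read_long() perform the inverse.
    (n,) = struct.unpack(">i", buf[:4])
    df_id = buf[4:4 + n].decode("utf-8")
    (batch_id,) = struct.unpack(">q", buf[4 + n:12 + n])
    return df_id, batch_id

req = encode_batch_request("9f1c-example-df-id", 42)
assert decode_batch_request(req) == ("9f1c-example-df-id", 42)
# The worker's reply is write_int(0), i.e. struct.pack(">i", 0), read back by readInt() on the JVM.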