This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 5bfaa71d7bc [SPARK-42944][SS][PYTHON] Streaming ForeachBatch in Python
5bfaa71d7bc is described below

commit 5bfaa71d7bc63a19c73bc0208eccf1d68dbf6ac7
Author: Raghu Angadi <raghu.ang...@databricks.com>
AuthorDate: Wed Jul 19 09:04:05 2023 +0900

    [SPARK-42944][SS][PYTHON] Streaming ForeachBatch in Python
    
    Adds `foreachBatch()` in Python. This adds a new runner, `StreamingPythonRunner`.
    Note that this PR focuses on core functionality and includes TODOs for follow-up
    improvements (will update with JIRA tickets where missing).
    
    Included more inline comments to help with the review.
    
    ### What changes were proposed in this pull request?
    Adds support for `foreachBatch()` in Spark Connect.
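    
    For context, a minimal client-side usage sketch (not part of this PR; the connect
    URL and the `rate` source are illustrative only):
    
    ```python
    from pyspark.sql import SparkSession
    
    # Assumes a Spark Connect server is reachable at this URL.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
    
    def process_batch(batch_df, batch_id):
        # Runs on the server side; use the session attached to batch_df.
        batch_df.createOrReplaceGlobalTempView("latest_batch")
    
    query = (
        spark.readStream.format("rate").load()
        .writeStream.foreachBatch(process_batch)
        .start()
    )
    ```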
    
    ### Why are the changes needed?
    To support streaming `foreachBatch()` for PySpark users on Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `DataStreamWriter.foreachBatch()` is now supported in Spark Connect Python
    (it previously raised a `NOT_IMPLEMENTED` error).
    
    ### How was this patch tested?
      - Manual tests
      - Unit tests:
         - The tests are updated to use a global temp view rather than a shared
           variable, since the Connect version of the function runs on the server
           side (see the sketch below).
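    
    The updated unit test follows this pattern (a sketch mirroring the test diff below):
    
    ```python
    def collectBatch(batch_df, batch_id):
        # With Spark Connect the function runs on the server, so results are
        # surfaced through a global temp view instead of a shared Python variable.
        batch_df.createOrReplaceGlobalTempView("test_view")
    ```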
    
    Closes #42035 from rangadi/feb-py.
    
    Authored-by: Raghu Angadi <raghu.ang...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit d93f6d145142ec15a96b0a3bfbbaa044c4b725e9)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../spark/sql/streaming/StreamingQuerySuite.scala  |  2 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 +-
 .../planner/StreamingForeachBatchHelper.scala      | 70 +++++++++++++++--
 .../org/apache/spark/api/python/PythonRunner.scala |  1 +
 .../spark/api/python/PythonWorkerFactory.scala     |  9 ++-
 .../spark/api/python/StreamingPythonRunner.scala   | 88 ++++++++++++++++++++++
 python/pyspark/sql/connect/session.py              |  9 +++
 python/pyspark/sql/connect/streaming/readwriter.py | 12 +--
 .../test_parity_streaming_foreachBatch.py          | 44 +++++++++++
 .../tests/streaming/test_streaming_foreachBatch.py | 20 +++--
 python/pyspark/streaming_worker.py                 | 78 +++++++++++++++++++
 11 files changed, 311 insertions(+), 27 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 1287176d76e..91d744b9e48 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -335,7 +335,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging {
         .start()
 
       eventually(timeout(30.seconds)) { // Wait for first progress.
-        assert(q.lastProgress != null)
+        assert(q.lastProgress != null, "Failed to make progress")
         assert(q.lastProgress.numInputRows > 0)
       }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 39cb4c1b972..92a9524f67a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2792,7 +2792,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     if (writeOp.hasForeachBatch) {
       val foreachBatchFn = writeOp.getForeachBatch.getFunctionCase match {
         case StreamingForeachFunction.FunctionCase.PYTHON_FUNCTION =>
-          throw InvalidPlanInput("Python ForeachBatch is not supported yet. 
WIP.")
+          val pythonFn = 
transformPythonFunction(writeOp.getForeachBatch.getPythonFunction)
+          StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, 
sessionHolder)
 
         case StreamingForeachFunction.FunctionCase.SCALA_FUNCTION =>
          val scalaFn = Utils.deserialize[StreamingForeachBatchHelper.ForeachBatchFnType](
@@ -2801,7 +2802,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
          StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)
 
         case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
-          throw InvalidPlanInput("Unexpected")
+          throw InvalidPlanInput("Unexpected") // Unreachable
       }
 
       writer.foreachBatch(foreachBatchFn)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 66487e7048c..31481393777 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -18,9 +18,13 @@ package org.apache.spark.sql.connect.planner
 
 import java.util.UUID
 
+import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.SimplePythonFunction
+import org.apache.spark.api.python.StreamingPythonRunner
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.connect.service.SparkConnectService
 
 /**
 * A helper class for handling ForeachBatch related functionality in Spark Connect servers
@@ -29,23 +33,25 @@ object StreamingForeachBatchHelper extends Logging {
 
   type ForeachBatchFnType = (DataFrame, Long) => Unit
 
+  private case class FnArgsWithId(dfId: String, df: DataFrame, batchId: Long)
+
   /**
   * Return a new ForeachBatch function that wraps `fn`. It sets up DataFrame cache so that the
   * user function can access it. The cache is cleared once ForeachBatch returns.
    */
-  def dataFrameCachingWrapper(
-      fn: ForeachBatchFnType,
+  private def dataFrameCachingWrapper(
+      fn: FnArgsWithId => Unit,
      sessionHolder: SessionHolder): ForeachBatchFnType = { (df: DataFrame, batchId: Long) =>
     {
       val dfId = UUID.randomUUID().toString
       log.info(s"Caching DataFrame with id $dfId") // TODO: Add query id to 
the log.
 
-      // TODO: Sanity check there is no other active DataFrame for this query. Need to include
-      //       query id available in the cache for this check.
+      // TODO: Sanity check there is no other active DataFrame for this query. The query id
+      //       needs to be saved in the cache for this check.
 
       sessionHolder.cacheDataFrameById(dfId, df)
       try {
-        fn(df, batchId)
+        fn(FnArgsWithId(dfId, df, batchId))
       } finally {
         log.info(s"Removing DataFrame with id $dfId from the cache")
         sessionHolder.removeCachedDataFrame(dfId)
@@ -57,13 +63,61 @@ object StreamingForeachBatchHelper extends Logging {
   * Handles setting up Scala remote session and other Spark Connect environment and then runs the
    * provided foreachBatch function `fn`.
    *
-   * HACK ALERT: This version does not atually set up Spark connect. Directly passes the
-   * DataFrame, so the user code actually runs with legacy DataFrame.
+   * HACK ALERT: This version does not actually set up a Spark Connect session. Directly passes the
+   * DataFrame, so the user code actually runs with the legacy DataFrame and session.
    */
   def scalaForeachBatchWrapper(
       fn: ForeachBatchFnType,
       sessionHolder: SessionHolder): ForeachBatchFnType = {
    // TODO: Set up Spark Connect session. Do we actually need this for the first version?
-    dataFrameCachingWrapper(fn, sessionHolder)
+    dataFrameCachingWrapper(
+      (args: FnArgsWithId) => {
+        fn(args.df, args.batchId) // dfId is not used, see hack comment above.
+      },
+      sessionHolder)
   }
+
+  /**
+   * Starts up the Python worker and initializes it with the Python function. Returns a foreachBatch
+   * function that sets up the session and DataFrame cache and interacts with the Python
+   * worker to execute the user's function.
+   */
+  def pythonForeachBatchWrapper(
+      pythonFn: SimplePythonFunction,
+      sessionHolder: SessionHolder): ForeachBatchFnType = {
+
+    val port = SparkConnectService.localPort
+    val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+    val runner = StreamingPythonRunner(pythonFn, connectUrl)
+    val (dataOut, dataIn) = runner.init(sessionHolder.sessionId)
+
+    val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
+
+      // TODO(SPARK-44460): Support Auth credentials
+      // TODO(SPARK-44462): A new session id pointing to args.df.sparkSession needs to be created.
+      //     This is because MicroBatch execution clones the session during start.
+      //     The session attached to the foreachBatch dataframe is different from the one
+      //     the query was started with. `sessionHolder` here contains the latter.
+
+      PythonRDD.writeUTF(args.dfId, dataOut)
+      dataOut.writeLong(args.batchId)
+      dataOut.flush()
+
+      val ret = dataIn.readInt()
+      log.info(s"Python foreach batch for dfId ${args.dfId} completed (ret: 
$ret)")
+    }
+
+    dataFrameCachingWrapper(foreachBatchRunnerFn, sessionHolder)
+  }
+
+  // TODO(SPARK-44433): Improve termination of Processes
+  //   The goal is that when a query is terminated, the python process associated with foreachBatch
+  //   should be terminated. One way to do that is by registering a streaming query listener:
+  //   After pythonForeachBatchWrapper() is invoked by the SparkConnectPlanner.
+  //   At that time, we don't have the streaming queries yet.
+  //   Planner should call back into this helper with the query id when it starts it immediately
+  //   after. Save the query id to StreamingPythonRunner mapping. This mapping should be
+  //   part of the SessionHolder.
+  //   When a query is terminated, check the mapping and terminate any associated runner.
+  //   These runners should be terminated when a session is deleted (due to timeout, etc).
 }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index ffb10985768..2831ae74f56 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -799,6 +799,7 @@ private[spark] object SpecialLengths {
   val END_OF_STREAM = -4
   val NULL = -5
   val START_ARROW_STREAM = -6
+  val END_OF_MICRO_BATCH = -7
 }
 
 private[spark] object BarrierTaskContextMessageProtocol {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 19181bd98e1..6039f8d232b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -106,10 +106,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       }
       createThroughDaemon()
     } else {
-      createSimpleWorker()
+      createSimpleWorker(workerModule)
     }
   }
 
+  /** Creates a Python worker with `pyspark.streaming_worker` module. */
+  def createStreamingWorker(): (Socket, Option[Int]) = {
+    createSimpleWorker("pyspark.streaming_worker")
+  }
+
   /**
   * Connect to a worker launched through pyspark/daemon.py (by default), which forks python
   * processes itself to avoid the high cost of forking from Java. This currently only works
@@ -150,7 +155,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   /**
   * Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
    */
-  private def createSimpleWorker(): (Socket, Option[Int]) = {
+  private def createSimpleWorker(workerModule: String): (Socket, Option[Int]) = {
     var serverSocket: ServerSocket = null
     try {
       serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())
diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
new file mode 100644
index 00000000000..77dc88e0cfa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTHON_USE_DAEMON}
+
+
+private[spark] object StreamingPythonRunner {
+  def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = {
+    new StreamingPythonRunner(func, connectUrl)
+  }
+}
+
+private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
+  extends Logging {
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  protected val pythonVer: String = func.pythonVer
+
+  /**
+   * Initializes the Python worker for streaming functions. Sets up Spark Connect session
+   * to be used with the functions.
+   */
+  def init(sessionId: String): (DataOutputStream, DataInputStream) = {
+    log.info(s"Initializing Python runner (session: $sessionId ,pythonExec: 
$pythonExec")
+
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+    conf.set(PYTHON_USE_DAEMON, false)
+    envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
+
+    val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap)
+    val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker()
+
+    val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+    val dataOut = new DataOutputStream(stream)
+
+    // TODO: verify python version
+
+    // Send sessionId
+    PythonRDD.writeUTF(sessionId, dataOut)
+
+    // send the user function to python process
+    val command = func.command
+    dataOut.writeInt(command.length)
+    dataOut.write(command.toArray)
+    dataOut.flush()
+
+    val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+    val resFromPython = dataIn.readInt()
+    log.info(s"Runner initialization returned $resFromPython")
+
+    (dataOut, dataIn)
+  }
+}
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 52eab1bf5f9..37a5bdd9f9f 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -58,6 +58,7 @@ from pyspark.sql.connect.plan import (
     LogicalPlan,
     CachedLocalRelation,
     CachedRelation,
+    CachedRemoteRelation,
 )
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.streaming import DataStreamReader, StreamingQueryManager
@@ -670,6 +671,14 @@ class SparkSession:
 
     copyFromLocalToFs.__doc__ = PySparkSession.copyFromLocalToFs.__doc__
 
+    def _createRemoteDataFrame(self, remote_id: str) -> "DataFrame":
+        """
+        An internal API to reference a runtime DataFrame on the server side.
+        This is used in the ForeachBatch() runner, where the remote DataFrame refers to the
+        output of a micro batch.
+        """
+        return DataFrame.withPlan(CachedRemoteRelation(remote_id), self)
+
     @staticmethod
     def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
         """
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index 156a3ba87db..c8cd408404f 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -32,7 +32,7 @@ from pyspark.sql.streaming.readwriter import (
     DataStreamWriter as PySparkDataStreamWriter,
 )
 from pyspark.sql.types import Row, StructType
-from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkNotImplementedError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.session import SparkSession
@@ -495,14 +495,14 @@ class DataStreamWriter:
 
     foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__
 
-    # TODO (SPARK-42944): Implement and uncomment the doc
     def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter":
-        raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "foreachBatch()"},
+        self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
+            func
         )
+        self._write_proto.foreach_batch.python_function.python_ver = "%d.%d" % sys.version_info[:2]
+        return self
 
-    # foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__
+    foreachBatch.__doc__ = PySparkDataStreamWriter.foreachBatch.__doc__
 
     def _start_internal(
         self,
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
new file mode 100644
index 00000000000..c4aa936a43e
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming_foreachBatch.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
+    @unittest.skip("SPARK-44463: Error handling needs improvement in connect 
foreachBatch")
+    def test_streaming_foreachBatch_propagates_python_errors(self):
+        super().test_streaming_foreachBatch_propagates_python_errors()
+
+    @unittest.skip("This seems specific to py4j and pinned threads. The 
intention is unclear")
+    def test_streaming_foreachBatch_graceful_stop(self):
+        super().test_streaming_foreachBatch_graceful_stop()
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.streaming.test_parity_streaming_foreachBatch import *  # noqa: F401,E501
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
index 7e5720e4299..d4e185c3d85 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreachBatch.py
@@ -20,40 +20,40 @@ import time
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
-class StreamingTestsForeachBatch(ReusedSQLTestCase):
+class StreamingTestsForeachBatchMixin:
     def test_streaming_foreachBatch(self):
         q = None
-        collected = dict()
 
         def collectBatch(batch_df, batch_id):
-            collected[batch_id] = batch_df.collect()
+            batch_df.createOrReplaceGlobalTempView("test_view")
 
         try:
            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
             q = df.writeStream.foreachBatch(collectBatch).start()
             q.processAllAvailable()
-            self.assertTrue(0 in collected)
-            self.assertTrue(len(collected[0]), 2)
+            collected = self.spark.sql("select * from 
global_temp.test_view").collect()
+            self.assertTrue(len(collected), 2)
         finally:
             if q:
                 q.stop()
 
     def test_streaming_foreachBatch_tempview(self):
         q = None
-        collected = dict()
 
         def collectBatch(batch_df, batch_id):
             batch_df.createOrReplaceTempView("updates")
            # it should use the spark session within given DataFrame, as microbatch execution will
            # clone the session which is no longer same with the session used to start the
            # streaming query
-            collected[batch_id] = batch_df.sparkSession.sql("SELECT * FROM updates").collect()
+            assert len(batch_df.sparkSession.sql("SELECT * FROM updates").collect()) == 2
+            # Write to a global view to verify on the repl/client side.
+            batch_df.createOrReplaceGlobalTempView("temp_view")
 
         try:
            df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
             q = df.writeStream.foreachBatch(collectBatch).start()
             q.processAllAvailable()
-            self.assertTrue(0 in collected)
+            collected = self.spark.sql("SELECT * FROM 
global_temp.temp_view").collect()
             self.assertTrue(len(collected[0]), 2)
         finally:
             if q:
@@ -89,6 +89,10 @@ class StreamingTestsForeachBatch(ReusedSQLTestCase):
         self.assertIsNone(q.exception(), "No exception has to be propagated.")
 
 
+class StreamingTestsForeachBatch(StreamingTestsForeachBatchMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     import unittest
    from pyspark.sql.tests.streaming.test_streaming_foreachBatch import *  # noqa: F401
diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/streaming_worker.py
new file mode 100644
index 00000000000..490bae44d99
--- /dev/null
+++ b/python/pyspark/streaming_worker.py
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A worker for streaming foreachBatch and query listener in Spark Connect.
+"""
+import os
+
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    write_int,
+    read_long,
+    UTF8Deserializer,
+    CPickleSerializer,
+)
+from pyspark import worker
+from pyspark.sql import SparkSession
+
+pickleSer = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def main(infile, outfile):  # type: ignore[no-untyped-def]
+    log_name = "Streaming ForeachBatch worker"
+    connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
+    sessionId = utf8_deserializer.loads(infile)
+
+    print(f"{log_name} is starting with url {connect_url} and sessionId 
{sessionId}.")
+
+    sparkConnectSession = SparkSession.builder.remote(connect_url).getOrCreate()
+    sparkConnectSession._client._session_id = sessionId
+
+    # TODO(SPARK-44460): Pass credentials.
+    # TODO(SPARK-44461): Enable Process Isolation
+
+    func = worker.read_command(pickleSer, infile)
+    write_int(0, outfile)  # Indicate successful initialization
+
+    outfile.flush()
+
+    def process(dfId, batchId):  # type: ignore[no-untyped-def]
+        print(f"{log_name} Started batch {batchId} with DF id {dfId}")
+        batchDf = sparkConnectSession._createRemoteDataFrame(dfId)
+        func(batchDf, batchId)
+        print(f"{log_name} Completed batch {batchId} with DF id {dfId}")
+
+    while True:
+        dfRefId = utf8_deserializer.loads(infile)
+        batchId = read_long(infile)
+        process(dfRefId, int(batchId))  # TODO(SPARK-44463): Propagate error to the user.
+        write_int(0, outfile)
+        outfile.flush()
+
+
+if __name__ == "__main__":
+    print("Starting streaming worker")
+
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    write_int(os.getpid(), sock_file)
+    sock_file.flush()
+    main(sock_file, sock_file)

