This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e7fc4003b246 [SPARK-47812][CONNECT] Support Serialization of SparkSession for ForEachBatch worker
e7fc4003b246 is described below
commit e7fc4003b246bab743ab82d9e7bb77c0e2e5946e
Author: Martin Grund <[email protected]>
AuthorDate: Sat Apr 13 10:30:23 2024 +0900
[SPARK-47812][CONNECT] Support Serialization of SparkSession for ForEachBatch worker
### What changes were proposed in this pull request?
This patch adds support for registering custom dispatch handlers when serializing objects with the provided Cloudpickle library. This is necessary to provide compatibility when executing ForEachBatch functions in Structured Streaming.
A typical example of this behavior is the following test case:
```python
def curried_function(df):
    def inner(batch_df, batch_id):
        df.createOrReplaceTempView("updates")
        batch_df.createOrReplaceTempView("batch_updates")

    return inner

df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
other_df = self.spark.range(100)
df.writeStream.foreachBatch(curried_function(other_df)).start()
```
Here we curry a DataFrame into the function called during ForEachBatch, effectively passing state into it. Until now, serializing DataFrames and SparkSessions in Spark Connect was not possible, since the SparkSession carries the open gRPC connection and the DataFrame itself overrides certain magic methods, which makes pickling fail.
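As a minimal, self-contained illustration of the connection side of the problem (the class and attribute names below are invented for the example and are not PySpark APIs), default pickling fails as soon as an object holds a live network resource:
```python
import pickle
import socket


class ToyClient:
    """Stand-in for a client that keeps an open network channel (like gRPC)."""

    def __init__(self) -> None:
        self.channel = socket.socket()  # live OS resources cannot be pickled


try:
    pickle.dumps(ToyClient())
except TypeError as exc:
    print(f"default pickling fails: {exc}")  # e.g. "cannot pickle 'socket' object"
```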
To make serializing SparkSessions possible, we register a custom session constructor that simply returns the currently active session during the serialization of the ForEachBatch function. When the ForEachBatch worker starts execution, it already creates and registers an active SparkSession. To serialize and reconstruct the DataFrame, we simply have to pass in the session and the plan; the remaining attributes do not carry permanent state.
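The following sketch shows this reconstruction pattern in plain Python. It is only an illustration of the idea; `ToySession`, `ToyDataFrame`, and `_active_session` are hypothetical names, not the actual Connect classes.
```python
import pickle

# The worker process registers its session here before unpickling the user function.
_active_session = None


def _resolve_active_session(expected_id):
    # Constructor emitted by ToySession.__reduce__: re-resolve instead of restoring state.
    if _active_session is None or _active_session.session_id != expected_id:
        raise RuntimeError("no matching active session in this process")
    return _active_session


class ToySession:
    def __init__(self, session_id, channel=None):
        self.session_id = session_id
        self.channel = channel  # unpicklable in real life; never serialized

    def __reduce__(self):
        # Serialize only the identifier; unpickling returns the active session.
        return (_resolve_active_session, (self.session_id,))


class ToyDataFrame:
    def __init__(self, plan, session):
        self.plan = plan
        self.session = session

    def __reduce__(self):
        # The plan and the session are enough to rebuild the object; the other
        # attributes (caches, display flags) carry no permanent state.
        return (ToyDataFrame, (self.plan, self.session))


# "Client" side: pickle a DataFrame bound to the active session.
_active_session = ToySession("abc", channel=object())
payload = pickle.dumps(ToyDataFrame("SELECT 1", _active_session))

# "Worker" side: with an active session registered, unpickling re-binds to it.
restored = pickle.loads(payload)
assert restored.session is _active_session and restored.plan == "SELECT 1"
```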
To avoid modifying any global behavior, the serialization handlers are not registered in all cases but only when the ForEachBatch and ForEach handlers are called. This makes sure that we do not unexpectedly change behavior elsewhere.
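For illustration only, one way such call-scoped registration can be expressed with the standard library's `copyreg` module is sketched below. This is an assumption about the general technique, not the exact wiring in this patch, and `ToySession`, `_session_reducer`, and `dumps_with_session_support` are hypothetical names.
```python
import copyreg
import pickle


class ToySession:
    """Hypothetical session type standing in for the real SparkSession."""

    def __init__(self, session_id):
        self.session_id = session_id


def _session_reducer(session):
    # Reconstruct from the session id only; unpicklable state is never serialized.
    return (ToySession, (session.session_id,))


def dumps_with_session_support(obj):
    """Pickle `obj` with the custom reducer enabled only for this call."""
    previous = copyreg.dispatch_table.get(ToySession)
    copyreg.pickle(ToySession, _session_reducer)
    try:
        return pickle.dumps(obj)
    finally:
        # Restore the previous state so pickling behaviour elsewhere is unchanged.
        if previous is None:
            copyreg.dispatch_table.pop(ToySession, None)
        else:
            copyreg.dispatch_table[ToySession] = previous


payload = dumps_with_session_support({"captured": ToySession("abc")})
assert pickle.loads(payload)["captured"].session_id == "abc"
```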
### Why are the changes needed?
Compatibility and Ease of Use
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added and updated tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #46002 from grundprinzip/SPARK-47812.
Lead-authored-by: Martin Grund <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 22 +++++++
python/pyspark/sql/connect/session.py | 37 ++++++++++++
.../streaming/worker/foreach_batch_worker.py | 15 ++++-
.../connect/streaming/worker/listener_worker.py | 15 ++++-
.../connect/streaming/test_parity_foreach_batch.py | 70 +++++++++++++++++-----
.../connect/streaming/test_parity_listener.py | 23 ++-----
.../pyspark/sql/tests/connect/test_parity_udtf.py | 18 +++++-
7 files changed, 163 insertions(+), 37 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 1dddcc078810..f0dc412760a4 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -122,6 +122,28 @@ class DataFrame:
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
+ def __reduce__(self) -> Tuple:
+ """
+ Custom method for serializing the DataFrame object using Pickle. Since the DataFrame
+ overrides "__getattr__" method, the default serialization method does not work.
+
+ Returns
+ -------
+ The tuple containing the information needed to reconstruct the object.
+
+ """
+ return (
+ DataFrame,
+ (
+ self._plan,
+ self._session,
+ ),
+ {
+ "_support_repr_html": self._support_repr_html,
+ "_cached_schema": self._cached_schema,
+ },
+ )
+
def __repr__(self) -> str:
if not self._support_repr_html:
(
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 07fe8a62f082..3be6c83cf13b 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -96,6 +96,7 @@ from pyspark.errors import (
PySparkRuntimeError,
PySparkValueError,
PySparkTypeError,
+ PySparkAssertionError,
)
if TYPE_CHECKING:
@@ -288,6 +289,26 @@ class SparkSession:
def getActiveSession(cls) -> Optional["SparkSession"]:
return getattr(cls._active_session, "session", None)
+ @classmethod
+ def _getActiveSessionIfMatches(cls, session_id: str) -> "SparkSession":
+ """
+ Internal use only. This method is called from the custom handler
+ generated by __reduce__. To avoid serializing a WeakRef, we create a
+ custom classmethod to instantiate the SparkSession.
+ """
+ session = SparkSession.getActiveSession()
+ if session is None:
+ raise PySparkRuntimeError(
+ error_class="NO_ACTIVE_SESSION",
+ message_parameters={},
+ )
+ if session._session_id != session_id:
+ raise PySparkAssertionError(
+ "Expected session ID does not match active session ID: "
+ f"{session_id} != {session._session_id}"
+ )
+ return session
+
getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__
@classmethod
@@ -1034,6 +1055,22 @@ class SparkSession:
profile.__doc__ = PySparkSession.profile.__doc__
+ def __reduce__(self) -> Tuple:
+ """
+ This method is called when the object is pickled. It returns a tuple of the object's
+ constructor function, arguments to it and the local state of the object.
+ This function is supposed to only be used when the active spark session that is pickled
+ is the same active spark session that is unpickled.
+ """
+
+ def creator(old_session_id: str) -> "SparkSession":
+ # We cannot perform the checks for session matching here because accessing the
+ # session ID property causes the serialization of a WeakRef and in turn breaks
+ # the serialization.
+ return SparkSession._getActiveSessionIfMatches(old_session_id)
+
+ return creator, (self._session_id,)
+
SparkSession.__doc__ = PySparkSession.__doc__
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index c4cf52b9996d..92ed7a4aaff5 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -29,7 +29,7 @@ from pyspark.serializers import (
CPickleSerializer,
)
from pyspark import worker
-from pyspark.sql import SparkSession
+from pyspark.sql.connect.session import SparkSession
from pyspark.util import handle_worker_exception
from typing import IO
from pyspark.worker_util import check_python_version
@@ -38,9 +38,16 @@ pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
+spark = None
+
+
def main(infile: IO, outfile: IO) -> None:
+ global spark
check_python_version(infile)
+ # Enable Spark Connect Mode
+ os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
+
connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
session_id = utf8_deserializer.loads(infile)
@@ -49,8 +56,11 @@ def main(infile: IO, outfile: IO) -> None:
f"url {connect_url} and sessionId {session_id}."
)
+ # To attach to the existing SparkSession, we're setting the session_id in the URL.
+ connect_url = connect_url + ";session_id=" + session_id
spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
- spark_connect_session._client._session_id = session_id # type: ignore[attr-defined]
+ assert spark_connect_session.session_id == session_id
+ spark = spark_connect_session
# TODO(SPARK-44461): Enable Process Isolation
@@ -62,6 +72,7 @@ def main(infile: IO, outfile: IO) -> None:
log_name = "Streaming ForeachBatch worker"
def process(df_id, batch_id): # type: ignore[no-untyped-def]
+ global spark
print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
batch_df = spark_connect_session._create_remote_dataframe(df_id)
func(batch_df, batch_id)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index 69e0d8a46248..d3efb5894fc0 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -30,7 +30,7 @@ from pyspark.serializers import (
CPickleSerializer,
)
from pyspark import worker
-from pyspark.sql import SparkSession
+from pyspark.sql.connect.session import SparkSession
from pyspark.util import handle_worker_exception
from typing import IO
@@ -46,9 +46,16 @@ pickle_ser = CPickleSerializer()
utf8_deserializer = UTF8Deserializer()
+spark = None
+
+
def main(infile: IO, outfile: IO) -> None:
+ global spark
check_python_version(infile)
+ # Enable Spark Connect Mode
+ os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
+
connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
session_id = utf8_deserializer.loads(infile)
@@ -57,8 +64,11 @@ def main(infile: IO, outfile: IO) -> None:
f"url {connect_url} and sessionId {session_id}."
)
+ # To attach to the existing SparkSession, we're setting the session_id in the URL.
+ connect_url = connect_url + ";session_id=" + session_id
spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
- spark_connect_session._client._session_id = session_id # type: ignore[attr-defined]
+ assert spark_connect_session.session_id == session_id
+ spark = spark_connect_session
# TODO(SPARK-44461): Enable Process Isolation
@@ -71,6 +81,7 @@ def main(infile: IO, outfile: IO) -> None:
assert listener.spark == spark_connect_session
def process(listener_event_str, listener_event_type): # type: ignore[no-untyped-def]
+ global spark
listener_event = json.loads(listener_event_str)
if listener_event_type == 0:
listener.onQueryStarted(QueryStartedEvent.fromJson(listener_event))
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
index 30f7bb8c2df9..4598cbbdca4e 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py
@@ -30,33 +30,73 @@ class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedCo
def test_streaming_foreach_batch_graceful_stop(self):
super().test_streaming_foreach_batch_graceful_stop()
+ def test_nested_dataframes(self):
+ def curried_function(df):
+ def inner(batch_df, batch_id):
+ df.createOrReplaceTempView("updates")
+ batch_df.createOrReplaceTempView("batch_updates")
+
+ return inner
+
+ try:
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ other_df = self.spark.range(100)
+ q = df.writeStream.foreachBatch(curried_function(other_df)).start()
+ q.processAllAvailable()
+ collected = self.spark.sql("select * from batch_updates").collect()
+ self.assertTrue(len(collected), 2)
+ self.assertEqual(100, self.spark.sql("select * from updates").count())
+ finally:
+ if q:
+ q.stop()
+
+ def test_pickling_error(self):
+ class NoPickle:
+ def __reduce__(self):
+ raise ValueError("No pickle")
+
+ no_pickle = NoPickle()
+
+ def func(df, _):
+ print(no_pickle)
+ df.count()
+
+ with self.assertRaises(PySparkPicklingError):
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+
def test_accessing_spark_session(self):
spark = self.spark
def func(df, _):
- spark.createDataFrame([("do", "not"), ("serialize",
"spark")]).collect()
+ spark.createDataFrame([("you", "can"), ("serialize",
"spark")]).createOrReplaceTempView(
+ "test_accessing_spark_session"
+ )
- error_thrown = False
try:
- self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
- except PySparkPicklingError as e:
- self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
- error_thrown = True
- self.assertTrue(error_thrown)
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+ self.assertEqual(2, spark.table("test_accessing_spark_session").count())
+ finally:
+ if q:
+ q.stop()
def test_accessing_spark_session_through_df(self):
- dataframe = self.spark.createDataFrame([("do", "not"), ("serialize",
"dataframe")])
+ dataframe = self.spark.createDataFrame([("you", "can"), ("serialize",
"dataframe")])
def func(df, _):
- dataframe.collect()
+
dataframe.createOrReplaceTempView("test_accessing_spark_session_through_df")
- error_thrown = False
try:
- self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
- except PySparkPicklingError as e:
- self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
- error_thrown = True
- self.assertTrue(error_thrown)
+ df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")
+ q = df.writeStream.foreachBatch(func).start()
+ q.processAllAvailable()
+ self.assertEqual(2, self.spark.table("test_accessing_spark_session_through_df").count())
+ finally:
+ if q:
+ q.stop()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index f5ffa0154df1..a15e4547f67a 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -19,7 +19,6 @@ import unittest
import time
import pyspark.cloudpickle
-from pyspark.errors import PySparkPicklingError
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
from pyspark.sql.streaming.listener import StreamingQueryListener
from pyspark.sql.functions import count, lit
@@ -138,7 +137,9 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
- spark.createDataFrame([("do", "not"), ("serialize",
"spark")]).collect()
+ spark.createDataFrame(
+ [("you", "can"), ("serialize", "spark")]
+ ).createOrReplaceTempView("test_accessing_spark_session")
def onQueryProgress(self, event):
pass
@@ -149,16 +150,10 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
def onQueryTerminated(self, event):
pass
- error_thrown = False
- try:
- self.spark.streams.addListener(TestListener())
- except PySparkPicklingError as e:
- self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
- error_thrown = True
- self.assertTrue(error_thrown)
+ self.spark.streams.addListener(TestListener())
def test_accessing_spark_session_through_df(self):
- dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])
+ dataframe = self.spark.createDataFrame([("you", "can"), ("serialize", "dataframe")])
class TestListener(StreamingQueryListener):
def onQueryStarted(self, event):
@@ -173,13 +168,7 @@ class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTes
def onQueryTerminated(self, event):
pass
- error_thrown = False
- try:
- self.spark.streams.addListener(TestListener())
- except PySparkPicklingError as e:
- self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
- error_thrown = True
- self.assertTrue(error_thrown)
+ self.spark.streams.addListener(TestListener())
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 02570ac9efa7..5071b69060a1 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -28,7 +28,7 @@ if should_test_connect:
from pyspark.util import is_remote_only
from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.errors.exceptions.connect import SparkConnectGrpcException
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException, PythonException
class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
@@ -76,6 +76,10 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
def test_udtf_with_analyze_using_file(self):
super().test_udtf_with_analyze_using_file()
+ @unittest.skip("pyspark-connect can serialize SparkSession, but fails on executor")
+ def test_udtf_access_spark_session(self):
+ super().test_udtf_access_spark_session()
+
def _add_pyfile(self, path):
self.spark.addArtifacts(path, pyfile=True)
@@ -99,6 +103,18 @@ class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
finally:
super(ArrowUDTFParityTests, cls).tearDownClass()
+ def test_udtf_access_spark_session_connect(self):
+ df = self.spark.range(10)
+
+ @udtf(returnType="x: int")
+ class TestUDTF:
+ def eval(self):
+ df.collect()
+ yield 1,
+
+ with self.assertRaisesRegex(PythonException, "NO_ACTIVE_SESSION"):
+ TestUDTF().collect()
+
if __name__ == "__main__":
import unittest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]