This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bfe231c9a5a8 [SPARK-52539][CONNECT] Introduce session hooks
bfe231c9a5a8 is described below
commit bfe231c9a5a8bcb1fd175ba59e4e5d5144f56fb0
Author: Niklas Mohrin <[email protected]>
AuthorDate: Mon Jun 23 09:30:27 2025 +0900
[SPARK-52539][CONNECT] Introduce session hooks
### What changes were proposed in this pull request?
This PR introduces the concept of hooks in the Spark Connect session. A
hook registers code that runs right before an ExecutePlanRequest is sent and
may modify the request before it goes out. In the future, we might also
consider adding more methods to the Hook type that are called in other
places.
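As a rough illustration (not part of this patch; the `LoggingHook` name and the log message are made up), a hook is a small class whose `on_execute_plan` method sees every outgoing request:

```python
from pyspark.sql.connect.session import SparkSession


class LoggingHook(SparkSession.Hook):
    """Illustrative hook that logs every outgoing ExecutePlanRequest."""

    def __init__(self, session):
        # The class itself serves as the hook factory: it is called once
        # with the session the hook is created for.
        self._session = session

    def on_execute_plan(self, request):
        # Runs right before the request is sent; the returned request
        # replaces the original one, so it may also be modified here.
        print(f"executing plan for session {request.session_id}")
        return request
```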
### Why are the changes needed?
The motivation for this change is to allow injecting additional behavior
into Spark sessions without having to touch the code of the session itself.
The Hook interface is abstract and powerful enough to implement complex
behavior while maintaining a clear separation of concerns.
### Does this PR introduce _any_ user-facing change?
In this PR, the session builder has a new method `_registerHook`. The
method is private for now so that we can first experiment with the interface.
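Continuing the sketch above (the connection string is a placeholder and `LoggingHook` is the illustrative hook from before), registration would look roughly like this:

```python
spark = (
    SparkSession.builder.remote("sc://localhost")
    ._registerHook(LoggingHook)  # factory: called once with the new session
    .getOrCreate()
)

# Every plan execution now passes through LoggingHook.on_execute_plan
# before the request is sent to the server.
spark.range(1).collect()
```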
### How was this patch tested?
I added unit tests that check that a registered hook is created and called
as expected.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51233 from niklasmohrin/hooks.
Authored-by: Niklas Mohrin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 11 ++++++
python/pyspark/sql/connect/session.py | 45 +++++++++++++++++++---
.../sql/tests/connect/client/test_client.py | 33 ++++++++++++++++
3 files changed, 84 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 3cfb38fdfa7d..122372877a7b 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -109,6 +109,7 @@ from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_p
if TYPE_CHECKING:
from google.rpc.error_details_pb2 import ErrorInfo
from pyspark.sql.connect._typing import DataTypeOrString
+ from pyspark.sql.connect.session import SparkSession
from pyspark.sql.datasource import DataSource
@@ -606,6 +607,7 @@ class SparkConnectClient(object):
channel_options: Optional[List[Tuple[str, Any]]] = None,
retry_policy: Optional[Dict[str, Any]] = None,
use_reattachable_execute: bool = True,
+ session_hooks: Optional[list["SparkSession.Hook"]] = None,
):
"""
Creates a new SparkSession for the Spark Connect interface.
@@ -636,6 +638,8 @@ class SparkConnectClient(object):
a failed request. Default: 60000(ms).
use_reattachable_execute: bool
Enable reattachable execution.
+ session_hooks: list[SparkSession.Hook], optional
+ List of session hooks to call.
"""
self.thread_local = threading.local()
@@ -675,6 +679,7 @@ class SparkConnectClient(object):
self._user_id, self._session_id, self._channel, self._builder.metadata()
)
self._use_reattachable_execute = use_reattachable_execute
+ self._session_hooks = session_hooks or []
# Configure logging for the SparkConnect client.
# Capture the server-side session ID and set it to None initially. It will
@@ -1365,6 +1370,9 @@ class SparkConnectClient(object):
"""
logger.debug("Execute")
+ for hook in self._session_hooks:
+ req = hook.on_execute_plan(req)
+
def handle_response(b: pb2.ExecutePlanResponse) -> None:
self._verify_response_integrity(b)
@@ -1406,6 +1414,9 @@ class SparkConnectClient(object):
# when not at debug log level.
logger.debug(f"ExecuteAndFetchAsIterator. Request:
{self._proto_to_string(req)}")
+ for hook in self._session_hooks:
+ req = hook.on_execute_plan(req)
+
num_records = 0
def handle_response(
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 303b9c9aac12..fe923d471060 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -23,7 +23,7 @@ import json
import threading
import os
import warnings
-from collections.abc import Sized
+from collections.abc import Callable, Sized
import functools
from threading import RLock
from typing import (
@@ -33,6 +33,7 @@ from typing import (
Dict,
List,
Tuple,
+ TypeAlias,
Set,
cast,
overload,
@@ -106,6 +107,7 @@ from pyspark.errors import (
)
if TYPE_CHECKING:
+ import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.catalog import Catalog
from pyspark.sql.connect.udf import UDFRegistration
@@ -130,6 +132,7 @@ class SparkSession:
def __init__(self) -> None:
self._options: Dict[str, Any] = {}
self._channel_builder: Optional[DefaultChannelBuilder] = None
+ self._hook_factories: list["SparkSession.HookFactory"] = []
@overload
def config(self, key: str, value: Any) -> "SparkSession.Builder":
@@ -191,6 +194,11 @@ class SparkSession:
self._channel_builder = channelBuilder
return self
+ def _registerHook(self, hook_factory: "SparkSession.HookFactory") -> "SparkSession.Builder":
+ with self._lock:
+ self._hook_factories.append(hook_factory)
+ return self
+
def enableHiveSupport(self) -> "SparkSession.Builder":
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED", messageParameters={"feature":
"enableHiveSupport"}
@@ -235,11 +243,13 @@ class SparkSession:
if has_channel_builder:
assert self._channel_builder is not None
- session = SparkSession(connection=self._channel_builder)
+ session = SparkSession(
+ connection=self._channel_builder, hook_factories=self._hook_factories
+ )
else:
spark_remote = to_str(self._options.get("spark.remote"))
assert spark_remote is not None
- session = SparkSession(connection=spark_remote)
+ session = SparkSession(connection=spark_remote, hook_factories=self._hook_factories)
SparkSession._set_default_and_active_session(session)
self._apply_options(session)
@@ -255,6 +265,19 @@ class SparkSession:
self._apply_options(session)
return session
+ class Hook:
+ """A Hook can be used to inject behavior into the session."""
+
+ def on_execute_plan(self, request: "pb2.ExecutePlanRequest") -> "pb2.ExecutePlanRequest":
+ """Called before sending an ExecutePlanRequest.
+
+ The request is replaced with the one returned by this method.
+ """
+ return request
+
+ HookFactory: TypeAlias = Callable[["SparkSession"], Hook]
+ """A function that, given a session, returns a hook set up for it."""
+
_client: SparkConnectClient
# SPARK-47544: Explicitly declaring this as an identifier instead of a method.
@@ -262,7 +285,12 @@ class SparkSession:
builder: Builder = classproperty(lambda cls: cls.Builder()) # type: ignore
builder.__doc__ = PySparkSession.builder.__doc__
- def __init__(self, connection: Union[str, DefaultChannelBuilder], userId: Optional[str] = None):
+ def __init__(
+ self,
+ connection: Union[str, DefaultChannelBuilder],
+ userId: Optional[str] = None,
+ hook_factories: Optional[list[HookFactory]] = None,
+ ) -> None:
"""
Creates a new SparkSession for the Spark Connect interface.
@@ -277,8 +305,15 @@ class SparkSession:
isolate their Spark Sessions. If the `user_id` is not set, will default to
the $USER environment. Defining the user ID as part of the connection string
takes precedence.
+ hook_factories: list[HookFactory], optional
+ Optional list of hook factories for hooks that should be registered for this session.
"""
- self._client = SparkConnectClient(connection=connection, user_id=userId)
+ hook_factories = hook_factories or []
+ self._client = SparkConnectClient(
+ connection=connection,
+ user_id=userId,
+ session_hooks=[factory(self) for factory in hook_factories],
+ )
self._session_id = self._client._session_id
# Set to false to prevent client.release_session on close() (testing only)
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 647b950fd20f..43094b0e7e02 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -34,6 +34,7 @@ if should_test_connect:
DefaultPolicy,
)
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
+ from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.errors import PySparkRuntimeError, RetriesExceeded
import pyspark.sql.connect.proto as proto
@@ -261,6 +262,38 @@ class SparkConnectClientTestCase(unittest.TestCase):
client = SparkConnectClient(chan)
self.assertEqual(client._session_id, chan.session_id)
+ def test_session_hook(self):
+ inits = 0
+ calls = 0
+
+ class TestHook(RemoteSparkSession.Hook):
+ def __init__(self, _session):
+ nonlocal inits
+ inits += 1
+
+ def on_execute_plan(self, req):
+ nonlocal calls
+ calls += 1
+ return req
+
+ session = (
+ RemoteSparkSession.builder.remote("sc://foo")._registerHook(TestHook).getOrCreate()
+ )
+ self.assertEqual(inits, 1)
+ self.assertEqual(calls, 0)
+ session.client._stub = MockService(session.client._session_id)
+ session.client.disable_reattachable_execute()
+
+ # Called from _execute_and_fetch_as_iterator
+ session.range(1).collect()
+ self.assertEqual(inits, 1)
+ self.assertEqual(calls, 1)
+
+ # Called from _execute
+ session.udf.register("test_func", lambda x: x + 1)
+ self.assertEqual(inits, 1)
+ self.assertEqual(calls, 2)
+
def test_custom_operation_id(self):
client = SparkConnectClient("sc://foo/;token=bar",
use_reattachable_execute=False)
mock = MockService(client._session_id)