This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bfe231c9a5a8 [SPARK-52539][CONNECT] Introduce session hooks
bfe231c9a5a8 is described below

commit bfe231c9a5a8bcb1fd175ba59e4e5d5144f56fb0
Author: Niklas Mohrin <[email protected]>
AuthorDate: Mon Jun 23 09:30:27 2025 +0900

    [SPARK-52539][CONNECT] Introduce session hooks
    
    ### What changes were proposed in this pull request?
    
    This PR introduces the concept of Hooks in the Spark Connect session.
    A hook registers code that runs right before an ExecutePlanRequest is
    sent and may modify the request. In the future, we might also consider
    adding more methods to the Hook type to be called in other places.
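    
    For illustration, a hook subclasses `SparkSession.Hook` and overrides
    `on_execute_plan`. The sketch below is hypothetical and not part of this
    PR; appending to the request's `tags` field is just one example of a
    modification a hook could make:
    
    ```python
    from pyspark.sql.connect.session import SparkSession
    import pyspark.sql.connect.proto as pb2
    
    class TaggingHook(SparkSession.Hook):
        """Hypothetical hook that modifies each outgoing ExecutePlanRequest."""
    
        def __init__(self, session: SparkSession) -> None:
            self._session = session
    
        def on_execute_plan(self, request: pb2.ExecutePlanRequest) -> pb2.ExecutePlanRequest:
            # The request returned here replaces the original before it is sent.
            request.tags.append("tagged-by-hook")
            return request
    ```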
    
    ### Why are the changes needed?
    
    The motivation for this change is to allow injecting additional behavior
    into Spark sessions without having to touch the code of the session
    itself. The Hook interface is abstract and powerful enough to implement
    complicated behavior while maintaining a clear separation of concerns.
    
    ### Does this PR introduce _any_ user-facing change?
    
    In this PR, the session builder has a new method `_registerHook`. The
    method is private for now so that we can first experiment with the
    interface.
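    
    For example, mirroring the new unit test (a sketch; `sc://localhost` is
    a placeholder endpoint and `TaggingHook` is the hypothetical hook from
    above). A hook factory is any callable that takes the session and
    returns a `Hook`, so the hook class itself can be passed:
    
    ```python
    spark = (
        SparkSession.builder
        .remote("sc://localhost")
        ._registerHook(TaggingHook)  # invoked as TaggingHook(session) during getOrCreate()
        .getOrCreate()
    )
    ```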
    
    ### How was this patch tested?
    
    I added unit tests that verify that a registered hook is created and
    called as expected.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51233 from niklasmohrin/hooks.
    
    Authored-by: Niklas Mohrin <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/client/core.py          | 11 ++++++
 python/pyspark/sql/connect/session.py              | 45 +++++++++++++++++++---
 .../sql/tests/connect/client/test_client.py        | 33 ++++++++++++++++
 3 files changed, 84 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 3cfb38fdfa7d..122372877a7b 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -109,6 +109,7 @@ from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_p
 if TYPE_CHECKING:
     from google.rpc.error_details_pb2 import ErrorInfo
     from pyspark.sql.connect._typing import DataTypeOrString
+    from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.datasource import DataSource
 
 
@@ -606,6 +607,7 @@ class SparkConnectClient(object):
         channel_options: Optional[List[Tuple[str, Any]]] = None,
         retry_policy: Optional[Dict[str, Any]] = None,
         use_reattachable_execute: bool = True,
+        session_hooks: Optional[list["SparkSession.Hook"]] = None,
     ):
         """
         Creates a new SparkSession for the Spark Connect interface.
@@ -636,6 +638,8 @@ class SparkConnectClient(object):
                     a failed request. Default: 60000(ms).
         use_reattachable_execute: bool
             Enable reattachable execution.
+        session_hooks: list[SparkSession.Hook], optional
+            List of session hooks to call.
         """
         self.thread_local = threading.local()
 
@@ -675,6 +679,7 @@ class SparkConnectClient(object):
             self._user_id, self._session_id, self._channel, self._builder.metadata()
         )
         self._use_reattachable_execute = use_reattachable_execute
+        self._session_hooks = session_hooks or []
         # Configure logging for the SparkConnect client.
 
         # Capture the server-side session ID and set it to None initially. It will
@@ -1365,6 +1370,9 @@ class SparkConnectClient(object):
         """
         logger.debug("Execute")
 
+        for hook in self._session_hooks:
+            req = hook.on_execute_plan(req)
+
         def handle_response(b: pb2.ExecutePlanResponse) -> None:
             self._verify_response_integrity(b)
 
@@ -1406,6 +1414,9 @@ class SparkConnectClient(object):
             # when not at debug log level.
             logger.debug(f"ExecuteAndFetchAsIterator. Request: 
{self._proto_to_string(req)}")
 
+        for hook in self._session_hooks:
+            req = hook.on_execute_plan(req)
+
         num_records = 0
 
         def handle_response(
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 303b9c9aac12..fe923d471060 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -23,7 +23,7 @@ import json
 import threading
 import os
 import warnings
-from collections.abc import Sized
+from collections.abc import Callable, Sized
 import functools
 from threading import RLock
 from typing import (
@@ -33,6 +33,7 @@ from typing import (
     Dict,
     List,
     Tuple,
+    TypeAlias,
     Set,
     cast,
     overload,
@@ -106,6 +107,7 @@ from pyspark.errors import (
 )
 
 if TYPE_CHECKING:
+    import pyspark.sql.connect.proto as pb2
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
     from pyspark.sql.connect.udf import UDFRegistration
@@ -130,6 +132,7 @@ class SparkSession:
         def __init__(self) -> None:
             self._options: Dict[str, Any] = {}
             self._channel_builder: Optional[DefaultChannelBuilder] = None
+            self._hook_factories: list["SparkSession.HookFactory"] = []
 
         @overload
         def config(self, key: str, value: Any) -> "SparkSession.Builder":
@@ -191,6 +194,11 @@ class SparkSession:
                 self._channel_builder = channelBuilder
                 return self
 
+        def _registerHook(self, hook_factory: "SparkSession.HookFactory") -> "SparkSession.Builder":
+            with self._lock:
+                self._hook_factories.append(hook_factory)
+                return self
+
         def enableHiveSupport(self) -> "SparkSession.Builder":
             raise PySparkNotImplementedError(
                 errorClass="NOT_IMPLEMENTED", messageParameters={"feature": 
"enableHiveSupport"}
@@ -235,11 +243,13 @@ class SparkSession:
 
             if has_channel_builder:
                 assert self._channel_builder is not None
-                session = SparkSession(connection=self._channel_builder)
+                session = SparkSession(
+                    connection=self._channel_builder, hook_factories=self._hook_factories
+                )
             else:
                 spark_remote = to_str(self._options.get("spark.remote"))
                 assert spark_remote is not None
-                session = SparkSession(connection=spark_remote)
+                session = SparkSession(connection=spark_remote, hook_factories=self._hook_factories)
 
             SparkSession._set_default_and_active_session(session)
             self._apply_options(session)
@@ -255,6 +265,19 @@ class SparkSession:
                 self._apply_options(session)
                 return session
 
+    class Hook:
+        """A Hook can be used to inject behavior into the session."""
+
+        def on_execute_plan(self, request: "pb2.ExecutePlanRequest") -> "pb2.ExecutePlanRequest":
+            """Called before sending an ExecutePlanRequest.
+
+            The request is replaced with the one returned by this method.
+            """
+            return request
+
+    HookFactory: TypeAlias = Callable[["SparkSession"], Hook]
+    """A function that, given a session, returns a hook set up for it."""
+
     _client: SparkConnectClient
 
     # SPARK-47544: Explicitly declaring this as an identifier instead of a method.
@@ -262,7 +285,12 @@ class SparkSession:
     builder: Builder = classproperty(lambda cls: cls.Builder())  # type: ignore
     builder.__doc__ = PySparkSession.builder.__doc__
 
-    def __init__(self, connection: Union[str, DefaultChannelBuilder], userId: Optional[str] = None):
+    def __init__(
+        self,
+        connection: Union[str, DefaultChannelBuilder],
+        userId: Optional[str] = None,
+        hook_factories: Optional[list[HookFactory]] = None,
+    ) -> None:
         """
         Creates a new SparkSession for the Spark Connect interface.
 
@@ -277,8 +305,15 @@ class SparkSession:
             isolate their Spark Sessions. If the `user_id` is not set, will default to
             the $USER environment. Defining the user ID as part of the connection string
             takes precedence.
+        hook_factories: list[HookFactory], optional
+            Optional list of hook factories for hooks that should be registered for this session.
         """
-        self._client = SparkConnectClient(connection=connection, user_id=userId)
+        hook_factories = hook_factories or []
+        self._client = SparkConnectClient(
+            connection=connection,
+            user_id=userId,
+            session_hooks=[factory(self) for factory in hook_factories],
+        )
         self._session_id = self._client._session_id
 
         # Set to false to prevent client.release_session on close() (testing only)
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 647b950fd20f..43094b0e7e02 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -34,6 +34,7 @@ if should_test_connect:
         DefaultPolicy,
     )
     from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
+    from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.errors import PySparkRuntimeError, RetriesExceeded
     import pyspark.sql.connect.proto as proto
 
@@ -261,6 +262,38 @@ class SparkConnectClientTestCase(unittest.TestCase):
         client = SparkConnectClient(chan)
         self.assertEqual(client._session_id, chan.session_id)
 
+    def test_session_hook(self):
+        inits = 0
+        calls = 0
+
+        class TestHook(RemoteSparkSession.Hook):
+            def __init__(self, _session):
+                nonlocal inits
+                inits += 1
+
+            def on_execute_plan(self, req):
+                nonlocal calls
+                calls += 1
+                return req
+
+        session = (
+            RemoteSparkSession.builder.remote("sc://foo")._registerHook(TestHook).getOrCreate()
+        )
+        self.assertEqual(inits, 1)
+        self.assertEqual(calls, 0)
+        session.client._stub = MockService(session.client._session_id)
+        session.client.disable_reattachable_execute()
+
+        # Called from _execute_and_fetch_as_iterator
+        session.range(1).collect()
+        self.assertEqual(inits, 1)
+        self.assertEqual(calls, 1)
+
+        # Called from _execute
+        session.udf.register("test_func", lambda x: x + 1)
+        self.assertEqual(inits, 1)
+        self.assertEqual(calls, 2)
+
     def test_custom_operation_id(self):
         client = SparkConnectClient("sc://foo/;token=bar", 
use_reattachable_execute=False)
         mock = MockService(client._session_id)

