This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 264ca4dc320e [SPARK-53455][CONNECT] Add `CloneSession` RPC
264ca4dc320e is described below
commit 264ca4dc320e516bd88679ed2404f1c15d2f24eb
Author: vicennial <[email protected]>
AuthorDate: Mon Oct 13 11:01:58 2025 -0400
[SPARK-53455][CONNECT] Add `CloneSession` RPC
### What changes were proposed in this pull request?
Adds a new experimental/developer RPC `CloneSession` to the
`SparkConnectService`.
✅ CLONED (from SparkSession)
- SessionState - SQL configs, temp views, UDFs, catalog metadata
- ArtifactManager - JARs, files, classes added to session
- ManagedJobTags - Job group tags for tracking
- SharedState (reference) - Metastore, global temp views
- SparkContext (reference) - Core Spark engine
❌ NOT CLONED (SessionHolder - Spark Connect layer)
- planCache - (Partially analyzed) Logical plans for query optimization
- operationIds - Currently executing operations
- errorIdToError - Recent errors for debugging
- eventManager - Session lifecycle events
- dataFrameCache - DataFrames for foreachBatch callbacks
- mlCache - ML models and pipelines
- listenerCache - Streaming query listeners
- pipelineExecutions - Active pipeline contexts
- dataflowGraphRegistry - Registered dataflow graphs
- streamingForeachBatchRunnerCleanerCache - Python streaming workers
- pythonAccumulator - Python metrics collection
- Session timings - Start time, last access, custom timeout
The clone preserves all SQL/catalog state but creates a fresh runtime
environment. An analogy is cloning a database schema/config but not the active
connections, caches, or running jobs.
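At the wire level, `CloneSession` is a single unary RPC (`CloneSessionRequest` -> `CloneSessionResponse`).
As a rough sketch of its shape using the generated Python stubs (the endpoint, user id, and session id
below are placeholders; real clients should go through `SparkConnectClient` / `SparkSession.cloneSession`
rather than calling the stub directly):
```python
import grpc

from pyspark.sql.connect.proto import base_pb2, base_pb2_grpc

# Placeholder endpoint and identifiers, for illustration only.
channel = grpc.insecure_channel("localhost:15002")
stub = base_pb2_grpc.SparkConnectServiceStub(channel)

request = base_pb2.CloneSessionRequest(
    session_id="00112233-4455-6677-8899-aabbccddeeff",  # existing session to clone
    client_type="python",
)
request.user_context.user_id = "some_user"
# new_session_id is optional; when omitted, the server generates a fresh UUID.

response = stub.CloneSession(request)
print(response.new_session_id)              # client-facing ID of the cloned session
print(response.new_server_side_session_id)  # server-side ID of the cloned session
```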
### Why are the changes needed?
Spark Connect introduced resource isolation (via `ArtifactManager`, which has
since been ported to classic Spark), so the jars/pyfiles/artifacts added to one
session are isolated from every other session.
A remaining rough edge is that if a user wants to fork the state of a session
while keeping the two sessions independent, the only option is to create a new
session and re-upload/re-initialize all base jars/artifacts/pyfiles, etc.
Supporting cloning through the API removes this rough edge while keeping all the
benefits of session resource isolation.
### Does this PR introduce _any_ user-facing change?
Yes
```python
from pyspark.sql import SparkSession, Row

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
spark.conf.set("my.custom.config", "value")
spark.addArtifact("/path/to/my.jar")
spark.sql("CREATE TEMP VIEW my_view AS SELECT 1 AS id")
# Clone the session
cloned_spark = spark.cloneSession()
# The cloned session has all the same state
assert cloned_spark.conf.get("my.custom.config") == "value"
assert cloned_spark.sql("SELECT * FROM my_view").collect() == [Row(id=1)]
# But operations are isolated between sessions
cloned_spark.sql("DROP VIEW my_view") # Only affects cloned session
spark.sql("SELECT * FROM my_view").collect() # Original still works
```
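The target session ID can also be chosen by the caller, as long as it is a well-formed UUID; a
malformed ID is rejected with the new `INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT`
error condition. A short follow-up sketch continuing the example above (reading the ID back via
`session_id` mirrors the new Python test suite):
```python
import uuid

# Clone into a caller-chosen session ID.
custom_id = str(uuid.uuid4())
cloned_with_id = spark.cloneSession(custom_id)
assert cloned_with_id.session_id == custom_id

# A malformed target ID is rejected server-side.
try:
    spark.cloneSession("not-a-valid-uuid")
except Exception as e:
    assert "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT" in str(e)
```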
### How was this patch tested?
New unit tests along with new Python and Scala test suites.
### Was this patch authored or co-authored using generative AI tooling?
Co-authored with assistance from Claude Code.
Closes #52200 from vicennial/cloneAPI.
Authored-by: vicennial <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 23 ++
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/client/core.py | 66 +++++
python/pyspark/sql/connect/proto/base_pb2.py | 10 +-
python/pyspark/sql/connect/proto/base_pb2.pyi | 149 +++++++++++
python/pyspark/sql/connect/proto/base_pb2_grpc.py | 55 ++++
python/pyspark/sql/connect/session.py | 34 +++
.../tests/connect/test_connect_clone_session.py | 153 +++++++++++
.../sql/tests/test_connect_compatibility.py | 2 +-
.../spark/sql/connect/SparkSessionCloneSuite.scala | 64 +++++
.../src/main/protobuf/spark/connect/base.proto | 58 +++++
.../apache/spark/sql/connect/SparkSession.scala | 33 +++
.../client/CustomSparkConnectBlockingStub.scala | 13 +
.../sql/connect/client/SparkConnectClient.scala | 51 ++++
.../service/SparkConnectCloneSessionHandler.scala | 62 +++++
.../sql/connect/service/SparkConnectService.scala | 14 ++
.../service/SparkConnectSessionManager.scala | 67 +++++
.../service/SparkConnectCloneSessionSuite.scala | 279 +++++++++++++++++++++
18 files changed, 1130 insertions(+), 4 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 8f6687587f78..afe56f6db2fd 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -2507,6 +2507,29 @@
],
"sqlState" : "22P03"
},
+ "INVALID_CLONE_SESSION_REQUEST" : {
+ "message" : [
+ "Invalid session clone request."
+ ],
+ "subClass" : {
+ "TARGET_SESSION_ID_ALREADY_CLOSED" : {
+ "message" : [
+ "Cannot clone session to target session ID <targetSessionId> because
a session with this ID was previously closed."
+ ]
+ },
+ "TARGET_SESSION_ID_ALREADY_EXISTS" : {
+ "message" : [
+ "Cannot clone session to target session ID <targetSessionId> because
a session with this ID already exists."
+ ]
+ },
+ "TARGET_SESSION_ID_FORMAT" : {
+ "message" : [
+ "Target session ID <targetSessionId> for clone operation must be an
UUID string of the format '00112233-4455-6677-8899-aabbccddeeff'."
+ ]
+ }
+ },
+ "sqlState" : "42K04"
+ },
"INVALID_COLUMN_NAME_AS_PATH" : {
"message" : [
"The datasource <datasource> cannot save the column <columnName> because
its name contains some characters that are not allowed in file paths. Please,
use an alias to rename it."
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 8aab600071dc..945a2ac9189b 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1095,6 +1095,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_connect_basic",
"pyspark.sql.tests.connect.test_connect_dataframe_property",
"pyspark.sql.tests.connect.test_connect_channel",
+ "pyspark.sql.tests.connect.test_connect_clone_session",
"pyspark.sql.tests.connect.test_connect_error",
"pyspark.sql.tests.connect.test_connect_function",
"pyspark.sql.tests.connect.test_connect_collection",
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 741d612f53f4..d0d191dbd7fd 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -2105,3 +2105,69 @@ class SparkConnectClient(object):
ml_command_result = properties["ml_command_result"]
return ml_command_result.param.long
+
+ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient":
+ """
+ Clone this client session on the server side. The server-side session
is cloned with
+ all its current state (SQL configurations, temporary views, registered
functions,
+ catalog state) copied over to a new independent session. The returned
client with the
+ cloned session is isolated from this client's session - any subsequent
changes to
+ either session's server-side state will not be reflected in the other.
+
+ Parameters
+ ----------
+ new_session_id : str, optional
+ Custom session ID to use for the cloned session (must be a valid
UUID).
+ If not provided, a new UUID will be generated.
+
+ Returns
+ -------
+ SparkConnectClient
+ A new SparkConnectClient instance with the cloned session.
+
+ Notes
+ -----
+ This creates a new server-side session with the specified or generated
session ID
+ while preserving the current session's configuration and state.
+
+ .. note::
+ This is a developer API.
+ """
+ from pyspark.sql.connect.proto import base_pb2 as pb2
+
+ request = pb2.CloneSessionRequest(
+ session_id=self._session_id,
+ client_type="python",
+ )
+ if self._user_id is not None:
+ request.user_context.user_id = self._user_id
+
+ if new_session_id is not None:
+ request.new_session_id = new_session_id
+
+ for attempt in self._retrying():
+ with attempt:
+ response: pb2.CloneSessionResponse = self._stub.CloneSession(
+ request, metadata=self._builder.metadata()
+ )
+
+ # Assert that the returned session ID matches the requested ID if one was provided
+ if new_session_id is not None:
+ assert response.new_session_id == new_session_id, (
+ f"Returned session ID '{response.new_session_id}' does not
match "
+ f"requested ID '{new_session_id}'"
+ )
+
+ # Create a new client with the cloned session ID
+ new_connection = copy.deepcopy(self._builder)
+ new_connection.set(ChannelBuilder.PARAM_SESSION_ID, response.new_session_id)
+
+ # Create new client and explicitly set the session ID
+ new_client = SparkConnectClient(
+ connection=new_connection,
+ user_id=self._user_id,
+ use_reattachable_execute=self._use_reattachable_execute,
+ )
+ # Ensure the session ID is correctly set from the response
+ new_client._session_id = response.new_session_id
+ return new_client
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py
b/python/pyspark/sql/connect/proto/base_pb2.py
index 269f5eefbd26..069e0830fcea 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -45,7 +45,7 @@ from pyspark.sql.connect.proto import pipelines_pb2 as
spark_dot_connect_dot_pip
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.Com [...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.Com [...]
)
_globals = globals()
@@ -260,6 +260,10 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 17518
_globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 17539
_globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 17629
- _globals["_SPARKCONNECTSERVICE"]._serialized_start = 17632
- _globals["_SPARKCONNECTSERVICE"]._serialized_end = 18578
+ _globals["_CLONESESSIONREQUEST"]._serialized_start = 17632
+ _globals["_CLONESESSIONREQUEST"]._serialized_end = 17994
+ _globals["_CLONESESSIONRESPONSE"]._serialized_start = 17997
+ _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18201
+ _globals["_SPARKCONNECTSERVICE"]._serialized_start = 18204
+ _globals["_SPARKCONNECTSERVICE"]._serialized_end = 19241
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 3361649d9323..dc3099ecdffc 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -4081,3 +4081,152 @@ class
CheckpointCommandResult(google.protobuf.message.Message):
) -> None: ...
global___CheckpointCommandResult = CheckpointCommandResult
+
+class CloneSessionRequest(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ CLIENT_OBSERVED_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
+ USER_CONTEXT_FIELD_NUMBER: builtins.int
+ CLIENT_TYPE_FIELD_NUMBER: builtins.int
+ NEW_SESSION_ID_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """(Required)
+
+ The session_id specifies a spark session for a user id (which is specified
+ by user_context.user_id). The session_id is set by the client to be able to
+ collate streaming responses from different queries within the dedicated
session.
+ The id should be an UUID string of the format
`00112233-4455-6677-8899-aabbccddeeff`
+ """
+ client_observed_server_side_session_id: builtins.str
+ """(Optional)
+
+ Server-side generated idempotency key from the previous responses (if
any). Server
+ can use this to validate that the server side session has not changed.
+ """
+ @property
+ def user_context(self) -> global___UserContext:
+ """(Required) User context
+
+ user_context.user_id and session_id both identify a unique remote
spark session on the
+ server side.
+ """
+ client_type: builtins.str
+ """Provides optional information about the client sending the request.
This field
+ can be used for language or version specific information and is only
intended for
+ logging purposes and will not be interpreted by the server.
+ """
+ new_session_id: builtins.str
+ """(Optional)
+ The session_id for the new cloned session. If not provided, a new UUID
will be generated.
+ The id should be an UUID string of the format
`00112233-4455-6677-8899-aabbccddeeff`
+ """
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ client_observed_server_side_session_id: builtins.str | None = ...,
+ user_context: global___UserContext | None = ...,
+ client_type: builtins.str | None = ...,
+ new_session_id: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_observed_server_side_session_id",
+ b"_client_observed_server_side_session_id",
+ "_client_type",
+ b"_client_type",
+ "_new_session_id",
+ b"_new_session_id",
+ "client_observed_server_side_session_id",
+ b"client_observed_server_side_session_id",
+ "client_type",
+ b"client_type",
+ "new_session_id",
+ b"new_session_id",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_observed_server_side_session_id",
+ b"_client_observed_server_side_session_id",
+ "_client_type",
+ b"_client_type",
+ "_new_session_id",
+ b"_new_session_id",
+ "client_observed_server_side_session_id",
+ b"client_observed_server_side_session_id",
+ "client_type",
+ b"client_type",
+ "new_session_id",
+ b"new_session_id",
+ "session_id",
+ b"session_id",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self,
+ oneof_group: typing_extensions.Literal[
+ "_client_observed_server_side_session_id",
b"_client_observed_server_side_session_id"
+ ],
+ ) -> typing_extensions.Literal["client_observed_server_side_session_id"] |
None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_client_type",
b"_client_type"]
+ ) -> typing_extensions.Literal["client_type"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_new_session_id",
b"_new_session_id"]
+ ) -> typing_extensions.Literal["new_session_id"] | None: ...
+
+global___CloneSessionRequest = CloneSessionRequest
+
+class CloneSessionResponse(google.protobuf.message.Message):
+ """Next ID: 5"""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
+ NEW_SESSION_ID_FIELD_NUMBER: builtins.int
+ NEW_SERVER_SIDE_SESSION_ID_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """Session id of the original session that was cloned."""
+ server_side_session_id: builtins.str
+ """Server-side generated idempotency key that the client can use to assert
that the server side
+ session (parent of the cloned session) has not changed.
+ """
+ new_session_id: builtins.str
+ """Session id of the new cloned session."""
+ new_server_side_session_id: builtins.str
+ """Server-side session ID of the new cloned session."""
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ server_side_session_id: builtins.str = ...,
+ new_session_id: builtins.str = ...,
+ new_server_side_session_id: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "new_server_side_session_id",
+ b"new_server_side_session_id",
+ "new_session_id",
+ b"new_session_id",
+ "server_side_session_id",
+ b"server_side_session_id",
+ "session_id",
+ b"session_id",
+ ],
+ ) -> None: ...
+
+global___CloneSessionResponse = CloneSessionResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index 7501aaf0a3a2..16dd20b563f3 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -90,6 +90,12 @@ class SparkConnectServiceStub(object):
response_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString,
_registered_method=True,
)
+ self.CloneSession = channel.unary_unary(
+ "/spark.connect.SparkConnectService/CloneSession",
+ request_serializer=spark_dot_connect_dot_base__pb2.CloneSessionRequest.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.CloneSessionResponse.FromString,
+ _registered_method=True,
+ )
class SparkConnectServiceServicer(object):
@@ -172,6 +178,20 @@ class SparkConnectServiceServicer(object):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
+ def CloneSession(self, request, context):
+ """Create a clone of a Spark Connect session on the server side. The
server-side session
+ is cloned with all its current state (SQL configurations, temporary
views, registered
+ functions, catalog state) copied over to a new independent session.
The cloned session
+ is isolated from the source session - any subsequent changes to either
session's
+ server-side state will not be reflected in the other.
+
+ The request can optionally specify a custom session ID for the cloned
session (must be
+ a valid UUID). If not provided, a new UUID will be generated
automatically.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
def add_SparkConnectServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -225,6 +245,11 @@ def add_SparkConnectServiceServicer_to_server(servicer,
server):
request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.SerializeToString,
),
+ "CloneSession": grpc.unary_unary_rpc_method_handler(
+ servicer.CloneSession,
+ request_deserializer=spark_dot_connect_dot_base__pb2.CloneSessionRequest.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.CloneSessionResponse.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
"spark.connect.SparkConnectService", rpc_method_handlers
@@ -536,3 +561,33 @@ class SparkConnectService(object):
metadata,
_registered_method=True,
)
+
+ @staticmethod
+ def CloneSession(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/CloneSession",
+ spark_dot_connect_dot_base__pb2.CloneSessionRequest.SerializeToString,
+ spark_dot_connect_dot_base__pb2.CloneSessionResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ _registered_method=True,
+ )
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index 6ccffc718d06..f759137fac1d 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1175,6 +1175,40 @@ class SparkSession:
assert dt is not None
return dt
+ def cloneSession(self, new_session_id: Optional[str] = None) -> "SparkSession":
+ """
+ Create a clone of this Spark Connect session on the server side. The
server-side session
+ is cloned with all its current state (SQL configurations, temporary
views, registered
+ functions, catalog state) copied over to a new independent session.
The returned cloned
+ session is isolated from this session - any subsequent changes to
either session's
+ server-side state will not be reflected in the other.
+
+ Parameters
+ ----------
+ new_session_id : str, optional
+ Custom session ID to use for the cloned session (must be a valid
UUID).
+ If not provided, a new UUID will be generated.
+
+ Returns
+ -------
+ SparkSession
+ A new SparkSession instance with the cloned session.
+
+ Notes
+ -----
+ This creates a new server-side session with the specified or generated
session ID
+ while preserving the current session's configuration and state.
+
+ .. note::
+ This is a developer API.
+ """
+ cloned_client = self._client.clone(new_session_id)
+ # Create a new SparkSession with the cloned client directly
+ new_session = object.__new__(SparkSession)
+ new_session._client = cloned_client
+ new_session._session_id = cloned_client._session_id
+ return new_session
+
SparkSession.__doc__ = PySparkSession.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_clone_session.py
b/python/pyspark/sql/tests/connect/test_connect_clone_session.py
new file mode 100644
index 000000000000..1cef628a158c
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_connect_clone_session.py
@@ -0,0 +1,153 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import uuid
+import unittest
+
+from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+
+
+class SparkConnectCloneSessionTest(SparkConnectSQLTestCase):
+ def test_clone_session_basic(self):
+ """Test basic session cloning functionality."""
+ # Set a configuration in the original session
+ self.connect.sql("SET spark.test.original = 'value1'")
+
+ # Verify the original session has the value
+ original_value = self.connect.sql("SET
spark.test.original").collect()[0][1]
+ self.assertEqual(original_value, "'value1'")
+
+ # Clone the session
+ cloned_session = self.connect.cloneSession()
+
+ # Verify the configuration was copied
+ # (if cloning doesn't preserve dynamic configs, use a different approach)
+ cloned_value = cloned_session.sql("SET spark.test.original").collect()[0][1]
+ self.assertEqual(cloned_value, "'value1'")
+ # Verify that sessions are independent by setting different values
+ cloned_session.sql("SET spark.test.original = 'modified_clone'")
+ self.connect.sql("SET spark.test.original = 'modified_original'")
+
+ # Verify independence
+ original_final = self.connect.sql("SET
spark.test.original").collect()[0][1]
+ cloned_final = cloned_session.sql("SET
spark.test.original").collect()[0][1]
+
+ self.assertEqual(original_final, "'modified_original'")
+ self.assertEqual(cloned_final, "'modified_clone'")
+
+ def test_clone_session_with_custom_id(self):
+ """Test cloning session with a custom session ID."""
+ custom_session_id = str(uuid.uuid4())
+ cloned_session = self.connect.cloneSession(custom_session_id)
+
+ self.assertEqual(cloned_session.session_id, custom_session_id)
+ self.assertNotEqual(self.connect.session_id, cloned_session.session_id)
+
+ def test_clone_session_preserves_temp_views(self):
+ """Test that temporary views are preserved in cloned sessions."""
+ # Create a temporary view
+ df = self.connect.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])
+ df.createOrReplaceTempView("test_table")
+
+ # Verify original session can access the temp view
+ original_count = self.connect.sql("SELECT COUNT(*) FROM test_table").collect()[0][0]
+ self.assertEqual(original_count, 2)
+
+ # Clone the session
+ cloned_session = self.connect.cloneSession()
+
+ # Create temp view in cloned session for testing
+ df_clone = cloned_session.createDataFrame([(3, "c"), (4, "d")], ["id", "value"])
+ df_clone.createOrReplaceTempView("test_table")
+
+ # Verify both sessions have independent temp views
+ original_count_final = self.connect.sql("SELECT COUNT(*) FROM test_table").collect()[0][0]
+ cloned_count_final = cloned_session.sql("SELECT COUNT(*) FROM test_table").collect()[0][0]
+
+ self.assertEqual(original_count_final, 2) # Original data
+ self.assertEqual(cloned_count_final, 2) # New data
+
+ def test_temp_views_independence_after_cloning(self):
+ """Test that temp views are cloned and then can be modified
independently."""
+ # Create initial temp view before cloning
+ df = self.connect.createDataFrame([(1, "original")], ["id", "type"])
+ df.createOrReplaceTempView("shared_view")
+
+ # Verify original session can access the temp view
+ original_count_before = self.connect.sql("SELECT COUNT(*) FROM shared_view").collect()[0][0]
+ self.assertEqual(original_count_before, 1)
+
+ # Clone the session - temp views should be preserved
+ cloned_session = self.connect.cloneSession()
+
+ # Verify cloned session can access the same temp view (cloned)
+ cloned_count = cloned_session.sql("SELECT COUNT(*) FROM shared_view").collect()[0][0]
+ self.assertEqual(cloned_count, 1)
+
+ # Now modify the temp view in each session independently
+ # Replace the view in the original session
+ original_df = self.connect.createDataFrame(
+ [(2, "modified_original"), (3, "another")], ["id", "type"]
+ )
+ original_df.createOrReplaceTempView("shared_view")
+
+ # Replace the view in the cloned session
+ cloned_session.sql(
+ "CREATE OR REPLACE TEMPORARY VIEW shared_view AS SELECT 4 as id,
'cloned_data' as type"
+ )
+
+ # Verify they now have different content (independence after cloning)
+ original_count_after = self.connect.sql("SELECT COUNT(*) FROM shared_view").collect()[0][0]
+ cloned_count_after = cloned_session.sql("SELECT COUNT(*) FROM shared_view").collect()[0][0]
+
+ self.assertEqual(original_count_after, 2) # Original session: 2 rows
+ self.assertEqual(cloned_count_after, 1) # Cloned session: 1 row
+
+ def test_invalid_session_id_format(self):
+ """Test that invalid session ID format raises an exception."""
+ with self.assertRaises(Exception) as context:
+ self.connect.cloneSession("not-a-valid-uuid")
+
+ # Verify it contains our clone-specific error message
+ self.assertIn(
+ "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT",
str(context.exception)
+ )
+ self.assertIn("not-a-valid-uuid", str(context.exception))
+
+ def test_clone_session_auto_generated_id(self):
+ """Test that cloneSession() without arguments generates a valid
UUID."""
+ cloned_session = self.connect.cloneSession()
+
+ # Verify different session IDs
+ self.assertNotEqual(self.connect.session_id, cloned_session.session_id)
+
+ # Verify the new session ID is a valid UUID
+ try:
+ uuid.UUID(cloned_session.session_id)
+ except ValueError:
+ self.fail("Generated session ID is not a valid UUID")
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_connect_clone_session import *  # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py
b/python/pyspark/sql/tests/test_connect_compatibility.py
index b2e0cc6229c4..4ca1c04bf2d0 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -268,7 +268,7 @@ class ConnectCompatibilityTestsMixin:
"registerProgressHandler",
"removeProgressHandler",
}
- expected_missing_classic_methods = set()
+ expected_missing_classic_methods = {"cloneSession"}
self.check_compatibility(
ClassicSparkSession,
ConnectSparkSession,
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionCloneSuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionCloneSuite.scala
new file mode 100644
index 000000000000..97ad21b2675f
--- /dev/null
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionCloneSuite.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import java.util.UUID
+
+import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
+
+class SparkSessionCloneSuite extends ConnectFunSuite with RemoteSparkSession {
+
+ test("cloneSession() creates independent session with auto-generated ID") {
+ spark.sql("SET spark.test.original = 'value1'")
+
+ val clonedSession = spark.cloneSession()
+
+ // Sessions should have different IDs
+ assert(spark.sessionId !== clonedSession.sessionId)
+
+ // Modify original session setting
+ spark.sql("SET spark.test.original = 'modified_original'")
+
+ // Verify original session has the modified value
+ val originalFinal = spark.sql("SET
spark.test.original").collect().head.getString(1)
+ assert(originalFinal === "'modified_original'")
+
+ // Verify cloned session retains the original value
+ val clonedValue = clonedSession.sql("SET spark.test.original").collect().head.getString(1)
+ assert(clonedValue === "'value1'")
+ }
+
+ test("cloneSession(sessionId) creates session with specified ID") {
+ val customSessionId = UUID.randomUUID().toString
+
+ val clonedSession = spark.cloneSession(customSessionId)
+
+ // Verify the cloned session has the specified ID
+ assert(clonedSession.sessionId === customSessionId)
+ assert(spark.sessionId !== clonedSession.sessionId)
+ }
+
+ test("invalid session ID format throws exception") {
+ val ex = intercept[org.apache.spark.SparkException] {
+ spark.cloneSession("not-a-valid-uuid")
+ }
+ // Verify it contains our clone-specific error message
+ assert(ex.getMessage.contains("INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT"))
+ assert(ex.getMessage.contains("not-a-valid-uuid"))
+ }
+}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
index f34b169f821d..6e1029bf0a6a 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -1152,6 +1152,54 @@ message CheckpointCommandResult {
CachedRemoteRelation relation = 1;
}
+message CloneSessionRequest {
+ // (Required)
+ //
+ // The session_id specifies a spark session for a user id (which is specified
+ // by user_context.user_id). The session_id is set by the client to be able
to
+ // collate streaming responses from different queries within the dedicated
session.
+ // The id should be an UUID string of the format
`00112233-4455-6677-8899-aabbccddeeff`
+ string session_id = 1;
+
+ // (Optional)
+ //
+ // Server-side generated idempotency key from the previous responses (if
any). Server
+ // can use this to validate that the server side session has not changed.
+ optional string client_observed_server_side_session_id = 5;
+
+ // (Required) User context
+ //
+ // user_context.user_id and session_id both identify a unique remote spark
session on the
+ // server side.
+ UserContext user_context = 2;
+
+ // Provides optional information about the client sending the request. This
field
+ // can be used for language or version specific information and is only
intended for
+ // logging purposes and will not be interpreted by the server.
+ optional string client_type = 3;
+
+ // (Optional)
+ // The session_id for the new cloned session. If not provided, a new UUID
will be generated.
+ // The id should be an UUID string of the format
`00112233-4455-6677-8899-aabbccddeeff`
+ optional string new_session_id = 4;
+}
+
+// Next ID: 5
+message CloneSessionResponse {
+ // Session id of the original session that was cloned.
+ string session_id = 1;
+
+ // Server-side generated idempotency key that the client can use to assert
that the server side
+ // session (parent of the cloned session) has not changed.
+ string server_side_session_id = 2;
+
+ // Session id of the new cloned session.
+ string new_session_id = 3;
+
+ // Server-side session ID of the new cloned session.
+ string new_server_side_session_id = 4;
+}
+
// Main interface for the SparkConnect service.
service SparkConnectService {
@@ -1196,4 +1244,14 @@ service SparkConnectService {
// FetchErrorDetails retrieves the matched exception with details based on a
provided error id.
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns
(FetchErrorDetailsResponse) {}
+
+ // Create a clone of a Spark Connect session on the server side. The
server-side session
+ // is cloned with all its current state (SQL configurations, temporary
views, registered
+ // functions, catalog state) copied over to a new independent session. The
cloned session
+ // is isolated from the source session - any subsequent changes to either
session's
+ // server-side state will not be reflected in the other.
+ //
+ // The request can optionally specify a custom session ID for the cloned
session (must be
+ // a valid UUID). If not provided, a new UUID will be generated
automatically.
+ rpc CloneSession(CloneSessionRequest) returns (CloneSessionResponse) {}
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 646db83981fe..f7869a8b4dd8 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -684,6 +684,39 @@ class SparkSession private[sql] (
}
}
+ /**
+ * Create a clone of this Spark Connect session on the server side. The
server-side session is
+ * cloned with all its current state (SQL configurations, temporary views,
registered functions,
+ * catalog state) copied over to a new independent session. The returned
cloned session is
+ * isolated from this session - any subsequent changes to either session's
server-side state
+ * will not be reflected in the other.
+ *
+ * @note
+ * This creates a new server-side session with a new session ID while
preserving the current
+ * session's configuration and state.
+ */
+ @DeveloperApi
+ def cloneSession(): SparkSession = {
+ SparkSession.builder().client(client.cloneSession()).create()
+ }
+
+ /**
+ * Create a clone of this Spark Connect session on the server side with a
custom session ID. The
+ * server-side session is cloned with all its current state (SQL
configurations, temporary
+ * views, registered functions, catalog state) copied over to a new
independent session with the
+ * specified session ID. The returned cloned session is isolated from this
session.
+ *
+ * @param sessionId
+ * The custom session ID to use for the cloned session (must be a valid
UUID)
+ * @note
+ * This creates a new server-side session with the specified session ID
while preserving the
+ * current session's configuration and state.
+ */
+ @DeveloperApi
+ def cloneSession(sessionId: String): SparkSession = {
+ SparkSession.builder().client(client.clone(sessionId)).create()
+ }
+
override private[sql] def isUsable: Boolean = client.isSessionValid
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index d7867229248b..913f068fcf34 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -131,4 +131,17 @@ private[connect] class CustomSparkConnectBlockingStub(
}
}
}
+
+ def cloneSession(request: CloneSessionRequest): CloneSessionResponse = {
+ grpcExceptionConverter.convert(
+ request.getSessionId,
+ request.getUserContext,
+ request.getClientType) {
+ retryHandler.retry {
+ stubState.responseValidator.verifyResponse {
+ stub.cloneSession(request)
+ }
+ }
+ }
+ }
}
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 48f01a8042a6..3c328681dd9a 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -29,6 +29,7 @@ import com.google.protobuf.ByteString
import io.grpc._
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
import org.apache.spark.sql.connect.common.ProtoUtils
@@ -414,6 +415,56 @@ private[sql] class SparkConnectClient(
.build()
artifactManager.cacheArtifact(localRelation.toByteArray)
}
+
+ /**
+ * Clone this client session, creating a new session with the same
configuration and shared
+ * state as the current session but with independent runtime state.
+ *
+ * @return
+ * A new SparkConnectClient instance with the cloned session.
+ */
+ @DeveloperApi
+ def cloneSession(): SparkConnectClient = {
+ clone(None)
+ }
+
+ /**
+ * Clone this client session with a custom session ID, creating a new
session with the same
+ * configuration and shared state as the current session but with
independent runtime state.
+ *
+ * @param newSessionId
+ * Custom session ID to use for the cloned session (must be a valid UUID).
+ * @return
+ * A new SparkConnectClient instance with the cloned session.
+ */
+ @DeveloperApi
+ def clone(newSessionId: String): SparkConnectClient = {
+ clone(Some(newSessionId))
+ }
+
+ private def clone(newSessionId: Option[String]): SparkConnectClient = {
+ val requestBuilder = proto.CloneSessionRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType("scala")
+
+ newSessionId.foreach(requestBuilder.setNewSessionId)
+
+ val response: proto.CloneSessionResponse = bstub.cloneSession(requestBuilder.build())
+
+ // Assert that the returned session ID matches the requested ID if one was provided
+ newSessionId.foreach { expectedId =>
+ require(
+ response.getNewSessionId == expectedId,
+ s"Returned session ID '${response.getNewSessionId}' does not match " +
+ s"requested ID '$expectedId'")
+ }
+
+ // Create a new client with the cloned session ID
+ val newConfiguration = configuration.copy(sessionId = Some(response.getNewSessionId))
+ new SparkConnectClient(newConfiguration, configuration.createChannel())
+ }
}
object SparkConnectClient {
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionHandler.scala
new file mode 100644
index 000000000000..c0d6da938932
--- /dev/null
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionHandler.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+
+class SparkConnectCloneSessionHandler(
+ responseObserver: StreamObserver[proto.CloneSessionResponse])
+ extends Logging {
+
+ def handle(request: proto.CloneSessionRequest): Unit = {
+ val sourceKey = SessionKey(request.getUserContext.getUserId, request.getSessionId)
+ val newSessionId = if (request.hasNewSessionId && request.getNewSessionId.nonEmpty) {
+ request.getNewSessionId
+ } else {
+ UUID.randomUUID().toString
+ }
+ val previouslyObservedSessionId =
+ if (request.hasClientObservedServerSideSessionId && request.getClientObservedServerSideSessionId.nonEmpty) {
+ Some(request.getClientObservedServerSideSessionId)
+ } else {
+ None
+ }
+ // Get the original session to retrieve its server session ID for
validation
+ val originalSessionHolder = SparkConnectService.sessionManager
+ .getIsolatedSession(sourceKey, previouslyObservedSessionId)
+ val clonedSessionHolder = SparkConnectService.sessionManager.cloneSession(
+ sourceKey,
+ newSessionId,
+ previouslyObservedSessionId)
+ val response = proto.CloneSessionResponse
+ .newBuilder()
+ .setSessionId(request.getSessionId)
+ .setNewSessionId(clonedSessionHolder.sessionId)
+ .setServerSideSessionId(originalSessionHolder.serverSessionId)
+ .setNewServerSideSessionId(clonedSessionHolder.serverSessionId)
+ .build()
+ responseObserver.onNext(response)
+ responseObserver.onCompleted()
+ }
+}
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index cc1cb95b66c4..13ce2d64256b 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -230,6 +230,20 @@ class SparkConnectService(debug: Boolean) extends
AsyncService with BindableServ
}
}
+ override def cloneSession(
+ request: proto.CloneSessionRequest,
+ responseObserver: StreamObserver[proto.CloneSessionResponse]): Unit = {
+ try {
+ new SparkConnectCloneSessionHandler(responseObserver).handle(request)
+ } catch {
+ ErrorUtils.handleError(
+ "cloneSession",
+ observer = responseObserver,
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId)
+ }
+ }
+
private def methodWithCustomMarshallers(
methodDesc: MethodDescriptor[Message, Message]):
MethodDescriptor[Message, Message] = {
val recursionLimit =
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index 1c3cfd67f132..f28af0379a04 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -114,6 +114,73 @@ class SparkConnectSessionManager extends Logging {
Option(getSession(key, None))
}
+ /**
+ * Clone an existing session with a new session ID. Creates a new
SessionHolder with a cloned
+ * SparkSession that shares configuration and catalog state but has
independent caches and
+ * runtime state.
+ */
+ private[connect] def cloneSession(
+ sourceKey: SessionKey,
+ newSessionId: String,
+ previouslyObservedSessionId: Option[String]): SessionHolder = {
+
+ // Get source session (must exist)
+ val sourceSessionHolder = getIsolatedSession(sourceKey, None)
+
+ previouslyObservedSessionId.foreach(sessionId =>
+ validateSessionId(sourceKey, sourceSessionHolder.session.sessionUUID, sessionId))
+
+ val newKey = SessionKey(sourceKey.userId, newSessionId)
+
+ // Validate new sessionId for clone operation
+ validateCloneTargetSession(newKey)
+
+ // Create cloned session
+ val clonedSessionHolder = getSession(
+ newKey,
+ Some(() => {
+ val session = sessionStore.get(newKey)
+ if (session == null) {
+ // Clone the underlying SparkSession using cloneSession() which
preserves
+ // configuration, catalog, session state, temporary views, and
registered functions
+ val clonedSparkSession = sourceSessionHolder.session.cloneSession()
+
+ val newHolder = SessionHolder(newKey.userId, newKey.sessionId, clonedSparkSession)
+ newHolder.initializeSession()
+ newHolder
+ } else {
+ // A session was created in the meantime.
+ session
+ }
+ }))
+
+ clonedSessionHolder
+ }
+
+ private def validateCloneTargetSession(newKey: SessionKey): Unit = {
+ // Validate that sessionId is formatted like UUID before creating session.
+ try {
+ UUID.fromString(newKey.sessionId).toString
+ } catch {
+ case _: IllegalArgumentException =>
+ throw new SparkSQLException(
+ errorClass = "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT",
+ messageParameters = Map("targetSessionId" -> newKey.sessionId))
+ }
+ // Validate that session with that key has not been already closed.
+ if (closedSessionsCache.getIfPresent(newKey) != null) {
+ throw new SparkSQLException(
+ errorClass = "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_ALREADY_CLOSED",
+ messageParameters = Map("targetSessionId" -> newKey.sessionId))
+ }
+ // Validate that session with that key does not already exist.
+ if (sessionStore.containsKey(newKey)) {
+ throw new SparkSQLException(
+ errorClass = "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_ALREADY_EXISTS",
+ messageParameters = Map("targetSessionId" -> newKey.sessionId))
+ }
+ }
+
private def getSession(key: SessionKey, default: Option[() =>
SessionHolder]): SessionHolder = {
schedulePeriodicChecks() // Starts the maintenance thread if it hasn't
started yet.
diff --git
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
new file mode 100644
index 000000000000..922c239526f3
--- /dev/null
+++
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.test.SharedSparkSession
+
+class SparkConnectCloneSessionSuite extends SharedSparkSession with BeforeAndAfterEach {
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ }
+
+ test("clone session with invalid target session ID format") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val invalidSessionId = "not-a-valid-uuid"
+
+ // Create the source session first
+ SparkConnectService.sessionManager.getOrCreateIsolatedSession(sourceKey,
None)
+
+ val ex = intercept[SparkSQLException] {
+ SparkConnectService.sessionManager.cloneSession(sourceKey,
invalidSessionId, None)
+ }
+
+ assert(ex.getCondition == "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_FORMAT")
+ assert(ex.getMessage.contains("Target session ID not-a-valid-uuid"))
+ assert(ex.getMessage.contains("must be an UUID string"))
+ }
+
+ test("clone session with target session ID already closed") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+ val targetKey = SessionKey("testUser", targetSessionId)
+
+ // Create and then close a session to put it in the closed cache
+ SparkConnectService.sessionManager.getOrCreateIsolatedSession(sourceKey,
None)
+ SparkConnectService.sessionManager.getOrCreateIsolatedSession(targetKey,
None)
+ SparkConnectService.sessionManager.closeSession(targetKey)
+
+ val ex = intercept[SparkSQLException] {
+ SparkConnectService.sessionManager.cloneSession(sourceKey,
targetSessionId, None)
+ }
+
+ assert(
+ ex.getCondition ==
+ "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_ALREADY_CLOSED")
+ assert(ex.getMessage.contains(s"target session ID $targetSessionId"))
+ assert(ex.getMessage.contains("was previously closed"))
+ }
+
+ test("clone session with target session ID already exists") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+ val targetKey = SessionKey("testUser", targetSessionId)
+
+ // Create both source and target sessions
+ SparkConnectService.sessionManager.getOrCreateIsolatedSession(sourceKey,
None)
+ SparkConnectService.sessionManager.getOrCreateIsolatedSession(targetKey,
None)
+
+ val ex = intercept[SparkSQLException] {
+ SparkConnectService.sessionManager.cloneSession(sourceKey,
targetSessionId, None)
+ }
+
+ assert(
+ ex.getCondition ==
+ "INVALID_CLONE_SESSION_REQUEST.TARGET_SESSION_ID_ALREADY_EXISTS")
+ assert(ex.getMessage.contains(s"target session ID $targetSessionId"))
+ assert(ex.getMessage.contains("already exists"))
+ }
+
+ test("clone session with source session not found") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+
+ // Don't create the source session, so it doesn't exist
+ val ex = intercept[SparkSQLException] {
+ SparkConnectService.sessionManager.cloneSession(sourceKey,
targetSessionId, None)
+ }
+
+ // Source session errors should remain as standard INVALID_HANDLE errors
+ assert(ex.getCondition == "INVALID_HANDLE.SESSION_NOT_FOUND")
+ assert(ex.getMessage.contains("Session not found"))
+ }
+
+ test("successful clone session creates new session") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+
+ // Create source session
+ val sourceSession = SparkConnectService.sessionManager
+ .getOrCreateIsolatedSession(sourceKey, None)
+
+ // Clone the session
+ val clonedSession = SparkConnectService.sessionManager
+ .cloneSession(sourceKey, targetSessionId, None)
+
+ // Verify the cloned session has the expected session ID
+ assert(clonedSession.sessionId == targetSessionId)
+ assert(clonedSession.sessionId != sourceSession.sessionId)
+
+ // Both sessions should be different objects
+ assert(clonedSession != sourceSession)
+ }
+
+ test("cloned session copies all state (configs, temp views, UDFs, current
database)") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+
+ // Create source session and set up state
+ val sourceSession = SparkConnectService.sessionManager
+ .getOrCreateIsolatedSession(sourceKey, None)
+
+ // Set SQL configs
+ sourceSession.session.conf.set("spark.sql.custom.test.config",
"test-value")
+ sourceSession.session.conf.set("spark.sql.shuffle.partitions", "42")
+
+ // Create temp views
+ sourceSession.session.sql("CREATE TEMPORARY VIEW temp_view1 AS SELECT 1 as
col1, 2 as col2")
+ sourceSession.session.sql("CREATE TEMPORARY VIEW temp_view2 AS SELECT 3 as
col3")
+
+ // Register UDFs
+ sourceSession.session.udf.register("double_value", (x: Int) => x * 2)
+ sourceSession.session.udf.register("concat_strings", (a: String, b:
String) => a + b)
+
+ // Change database
+ sourceSession.session.sql("CREATE DATABASE IF NOT EXISTS test_db")
+ sourceSession.session.sql("USE test_db")
+
+ // Clone the session
+ val clonedSession = SparkConnectService.sessionManager
+ .cloneSession(sourceKey, targetSessionId, None)
+
+ // Verify all state is copied
+
+ // Configs are copied
+ assert(clonedSession.session.conf.get("spark.sql.custom.test.config") ==
"test-value")
+ assert(clonedSession.session.conf.get("spark.sql.shuffle.partitions") ==
"42")
+
+ // Temp views are accessible
+ val view1Result = clonedSession.session.sql("SELECT * FROM
temp_view1").collect()
+ assert(view1Result.length == 1)
+ assert(view1Result(0).getInt(0) == 1)
+ assert(view1Result(0).getInt(1) == 2)
+
+ val view2Result = clonedSession.session.sql("SELECT * FROM
temp_view2").collect()
+ assert(view2Result.length == 1)
+ assert(view2Result(0).getInt(0) == 3)
+
+ // UDFs are accessible
+ val udfResult1 = clonedSession.session.sql("SELECT
double_value(5)").collect()
+ assert(udfResult1(0).getInt(0) == 10)
+
+ val udfResult2 =
+ clonedSession.session.sql("SELECT concat_strings('hello',
'world')").collect()
+ assert(udfResult2(0).getString(0) == "helloworld")
+
+ // Current database is copied
+ assert(clonedSession.session.catalog.currentDatabase == "test_db")
+ }
+
+ test("sessions are independent after cloning (configs, temp views, UDFs)") {
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+
+ // Create and set up source session
+ val sourceSession = SparkConnectService.sessionManager
+ .getOrCreateIsolatedSession(sourceKey, None)
+ sourceSession.session.conf.set("spark.sql.custom.config", "initial")
+ sourceSession.session.sql("CREATE TEMPORARY VIEW shared_view AS SELECT 1
as value")
+ sourceSession.session.udf.register("shared_udf", (x: Int) => x + 1)
+
+ // Clone the session
+ val clonedSession = SparkConnectService.sessionManager
+ .cloneSession(sourceKey, targetSessionId, None)
+
+ // Test independence of configs
+ sourceSession.session.conf.set("spark.sql.custom.config",
"modified-source")
+ clonedSession.session.conf.set("spark.sql.custom.config", "modified-clone")
+ assert(sourceSession.session.conf.get("spark.sql.custom.config") ==
"modified-source")
+ assert(clonedSession.session.conf.get("spark.sql.custom.config") ==
"modified-clone")
+
+ // Test independence of temp views - modify shared view differently
+ sourceSession.session.sql(
+ "CREATE OR REPLACE TEMPORARY VIEW shared_view AS SELECT 10 as value")
+ clonedSession.session.sql(
+ "CREATE OR REPLACE TEMPORARY VIEW shared_view AS SELECT 20 as value")
+
+ // Each session should see its own version of the view
+ val sourceViewResult = sourceSession.session.sql("SELECT * FROM
shared_view").collect()
+ assert(sourceViewResult(0).getInt(0) == 10)
+
+ val cloneViewResult = clonedSession.session.sql("SELECT * FROM
shared_view").collect()
+ assert(cloneViewResult(0).getInt(0) == 20)
+
+ // Test independence of UDFs
+ sourceSession.session.udf.register("shared_udf", (x: Int) => x + 10)
+ clonedSession.session.udf.register("shared_udf", (x: Int) => x + 100)
+ assert(sourceSession.session.sql("SELECT
shared_udf(5)").collect()(0).getInt(0) == 15)
+ assert(clonedSession.session.sql("SELECT
shared_udf(5)").collect()(0).getInt(0) == 105)
+ }
+
+ test("cloned session copies artifacts and maintains independence") {
+ import java.nio.charset.StandardCharsets
+ import java.nio.file.{Files, Paths}
+
+ val sourceKey = SessionKey("testUser", UUID.randomUUID.toString)
+ val targetSessionId = UUID.randomUUID.toString
+
+ // Create source session
+ val sourceSession = SparkConnectService.sessionManager
+ .getOrCreateIsolatedSession(sourceKey, None)
+
+ // Add some test artifacts to source session
+ val tempFile = Files.createTempFile("test-artifact", ".txt")
+ Files.write(tempFile, "test content".getBytes(StandardCharsets.UTF_8))
+
+ // Add artifact to source session
+ val remotePath = Paths.get("test/artifact.txt")
+ sourceSession.artifactManager.addArtifact(remotePath, tempFile, None)
+
+ // Clone the session
+ val clonedSession = SparkConnectService.sessionManager
+ .cloneSession(sourceKey, targetSessionId, None)
+
+ // Verify sessions have different artifact managers
+ assert(sourceSession.artifactManager ne clonedSession.artifactManager)
+ assert(sourceSession.session.sessionUUID !=
clonedSession.session.sessionUUID)
+
+ // Test independence: add new artifacts to each session
+ val sourceOnlyFile = Files.createTempFile("source-only", ".txt")
+ Files.write(sourceOnlyFile, "source only
content".getBytes(StandardCharsets.UTF_8))
+ val sourceOnlyPath = Paths.get("jars/source.jar")
+ sourceSession.artifactManager.addArtifact(sourceOnlyPath, sourceOnlyFile,
None)
+
+ val clonedOnlyFile = Files.createTempFile("cloned-only", ".txt")
+ Files.write(clonedOnlyFile, "cloned only
content".getBytes(StandardCharsets.UTF_8))
+ val clonedOnlyPath = Paths.get("jars/cloned.jar")
+ clonedSession.artifactManager.addArtifact(clonedOnlyPath, clonedOnlyFile,
None)
+
+ // Use getAddedJars to verify independence (since it's a public API)
+ val sourceJars = sourceSession.artifactManager.getAddedJars.map(_.toString)
+ val clonedJars = clonedSession.artifactManager.getAddedJars.map(_.toString)
+
+ // Source should have source.jar but not cloned.jar
+ assert(sourceJars.exists(_.contains("source.jar")))
+ assert(!sourceJars.exists(_.contains("cloned.jar")))
+
+ // Cloned should have cloned.jar but not source.jar
+ assert(clonedJars.exists(_.contains("cloned.jar")))
+ assert(!clonedJars.exists(_.contains("source.jar")))
+
+ // Clean up
+ Files.deleteIfExists(tempFile)
+ Files.deleteIfExists(sourceOnlyFile)
+ Files.deleteIfExists(clonedOnlyFile)
+ sourceSession.artifactManager.close()
+ clonedSession.artifactManager.close()
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]