This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new e10676638f74 [SPARK-51650][ML][CONNECT] Support delete ml cached
objects in batch
e10676638f74 is described below
commit e10676638f7484509ffe4bddc62e297638330545
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Apr 1 09:08:36 2025 +0800
[SPARK-51650][ML][CONNECT] Support delete ml cached objects in batch
### What changes were proposed in this pull request?
Support delete ml cached objects in batch
### Why are the changes needed?
to save RPCs
Meta algorithms on the client side may generate/delete many models, e.g.
`CrossValidator`.
The existing implementation has to delete them one by one, while with this
change, they can be deleted in a batch.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
existing tests
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #50441 from zhengruifeng/ml_connect_batch_del.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit c83b88a037eee12f485498704fe587c76e69e8d3)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/client/core.py | 17 ++++++------
python/pyspark/sql/connect/proto/ml_pb2.py | 30 +++++++++++-----------
python/pyspark/sql/connect/proto/ml_pb2.pyi | 18 +++++++------
.../src/main/protobuf/spark/connect/ml.proto | 4 +--
.../apache/spark/sql/connect/ml/MLHandler.scala | 9 ++++---
5 files changed, 41 insertions(+), 37 deletions(-)
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index 3a351c11fc01..4c5a262ebdc8 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1971,14 +1971,17 @@ class SparkConnectClient(object):
self.thread_local.ml_caches = set()
if cache_id in self.thread_local.ml_caches:
- self._delete_ml_cache(cache_id)
+ self._delete_ml_cache([cache_id])
- def _delete_ml_cache(self, cache_id: str) -> None:
+ def _delete_ml_cache(self, cache_ids: List[str]) -> None:
# try best to delete the cache
try:
- command = pb2.Command()
-
command.ml_command.delete.obj_ref.CopyFrom(pb2.ObjectRef(id=cache_id))
- self.execute_command(command)
+ if len(cache_ids) > 0:
+ command = pb2.Command()
+ command.ml_command.delete.obj_refs.extend(
+ [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids]
+ )
+ self.execute_command(command)
except Exception:
pass
@@ -1987,6 +1990,4 @@ class SparkConnectClient(object):
self.thread_local.ml_caches = set()
self.disable_reattachable_execute()
- # Todo add a pattern to delete all model in one command
- for model_id in self.thread_local.ml_caches:
- self._delete_ml_cache(model_id)
+ self._delete_ml_cache([model_id for model_id in
self.thread_local.ml_caches])
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py
b/python/pyspark/sql/connect/proto/ml_pb2.py
index 666cb1efdd2b..a2475a94c447 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as
spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1
[...]
+
b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfd\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01
\x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02
\x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03
\x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1
[...]
)
_globals = globals()
@@ -54,21 +54,21 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
_globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_MLCOMMAND"]._serialized_start = 137
- _globals["_MLCOMMAND"]._serialized_end = 1412
+ _globals["_MLCOMMAND"]._serialized_end = 1414
_globals["_MLCOMMAND_FIT"]._serialized_start = 480
_globals["_MLCOMMAND_FIT"]._serialized_end = 658
_globals["_MLCOMMAND_DELETE"]._serialized_start = 660
- _globals["_MLCOMMAND_DELETE"]._serialized_end = 719
- _globals["_MLCOMMAND_WRITE"]._serialized_start = 722
- _globals["_MLCOMMAND_WRITE"]._serialized_end = 1132
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1034
- _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1092
- _globals["_MLCOMMAND_READ"]._serialized_start = 1134
- _globals["_MLCOMMAND_READ"]._serialized_end = 1215
- _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1218
- _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1401
- _globals["_MLCOMMANDRESULT"]._serialized_start = 1415
- _globals["_MLCOMMANDRESULT"]._serialized_end = 1818
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1608
- _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1803
+ _globals["_MLCOMMAND_DELETE"]._serialized_end = 721
+ _globals["_MLCOMMAND_WRITE"]._serialized_start = 724
+ _globals["_MLCOMMAND_WRITE"]._serialized_end = 1134
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1036
+ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1094
+ _globals["_MLCOMMAND_READ"]._serialized_start = 1136
+ _globals["_MLCOMMAND_READ"]._serialized_end = 1217
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1220
+ _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1403
+ _globals["_MLCOMMANDRESULT"]._serialized_start = 1417
+ _globals["_MLCOMMANDRESULT"]._serialized_end = 1820
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1610
+ _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1805
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi
b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 3a1e9155d71d..1d2e06d536e1 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -111,25 +111,27 @@ class MlCommand(google.protobuf.message.Message):
) -> typing_extensions.Literal["params"] | None: ...
class Delete(google.protobuf.message.Message):
- """Command to delete the cached object which could be a model
+ """Command to delete the cached objects which could be a model
or summary evaluated by a model
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- OBJ_REF_FIELD_NUMBER: builtins.int
+ OBJ_REFS_FIELD_NUMBER: builtins.int
@property
- def obj_ref(self) ->
pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ...
+ def obj_refs(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.ml_common_pb2.ObjectRef
+ ]: ...
def __init__(
self,
*,
- obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None
= ...,
+ obj_refs:
collections.abc.Iterable[pyspark.sql.connect.proto.ml_common_pb2.ObjectRef]
+ | None = ...,
) -> None: ...
- def HasField(
- self, field_name: typing_extensions.Literal["obj_ref", b"obj_ref"]
- ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["obj_ref", b"obj_ref"]
+ self, field_name: typing_extensions.Literal["obj_refs",
b"obj_refs"]
) -> None: ...
class Write(google.protobuf.message.Message):
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 6e469bb9027e..cdb077f5055a 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -48,10 +48,10 @@ message MlCommand {
Relation dataset = 3;
}
- // Command to delete the cached object which could be a model
+ // Command to delete the cached objects which could be a model
// or summary evaluated by a model
message Delete {
- ObjectRef obj_ref = 1;
+ repeated ObjectRef obj_refs = 1;
}
// Command to write ML operator
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 9a9e156f91cd..9e3472ba8181 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -171,11 +171,12 @@ private[connect] object MLHandler extends Logging {
}
case proto.MlCommand.CommandCase.DELETE =>
- val objId = mlCommand.getDelete.getObjRef.getId
var result = false
- if (!objId.contains(".")) {
- mlCache.remove(objId)
- result = true
+ mlCommand.getDelete.getObjRefsList.asScala.toArray.foreach { objId =>
+ if (!objId.getId.contains(".")) {
+ mlCache.remove(objId.getId)
+ result = true
+ }
}
proto.MlCommandResult
.newBuilder()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]