This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new e10676638f74 [SPARK-51650][ML][CONNECT] Support deleting ML cached objects in batch
e10676638f74 is described below

commit e10676638f7484509ffe4bddc62e297638330545
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Apr 1 09:08:36 2025 +0800

    [SPARK-51650][ML][CONNECT] Support deleting ML cached objects in batch
    
    ### What changes were proposed in this pull request?
    Support deleting ML cached objects in batch.
    
    ### Why are the changes needed?
    To reduce the number of RPCs between the client and the server.
    
    Meta-algorithms on the client side, e.g. `CrossValidator`, may generate and delete many models. The existing implementation has to delete them one by one; with this change, they can be deleted in a single batch.
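    
    For illustration, a minimal sketch of the batched delete command this enables (a hypothetical snippet: `pb2` stands for the `pyspark.sql.connect.proto` package as aliased in the client code below, and the cache ids are made up):
    
    ```python
    import pyspark.sql.connect.proto as pb2
    
    # Build one DELETE command carrying every object reference, so a single
    # RPC replaces one DELETE command (and one RPC) per cached object.
    command = pb2.Command()
    command.ml_command.delete.obj_refs.extend(
        [pb2.ObjectRef(id=cache_id) for cache_id in ["model_1", "model_2"]]
    )
    ```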
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #50441 from zhengruifeng/ml_connect_batch_del.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit c83b88a037eee12f485498704fe587c76e69e8d3)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/client/core.py          | 17 ++++++------
 python/pyspark/sql/connect/proto/ml_pb2.py         | 30 +++++++++++-----------
 python/pyspark/sql/connect/proto/ml_pb2.pyi        | 18 +++++++------
 .../src/main/protobuf/spark/connect/ml.proto       |  4 +--
 .../apache/spark/sql/connect/ml/MLHandler.scala    |  9 ++++---
 5 files changed, 41 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 3a351c11fc01..4c5a262ebdc8 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1971,14 +1971,17 @@ class SparkConnectClient(object):
             self.thread_local.ml_caches = set()
 
         if cache_id in self.thread_local.ml_caches:
-            self._delete_ml_cache(cache_id)
+            self._delete_ml_cache([cache_id])
 
-    def _delete_ml_cache(self, cache_id: str) -> None:
+    def _delete_ml_cache(self, cache_ids: List[str]) -> None:
         # try best to delete the cache
         try:
-            command = pb2.Command()
-            command.ml_command.delete.obj_ref.CopyFrom(pb2.ObjectRef(id=cache_id))
-            self.execute_command(command)
+            if len(cache_ids) > 0:
+                command = pb2.Command()
+                command.ml_command.delete.obj_refs.extend(
+                    [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids]
+                )
+                self.execute_command(command)
         except Exception:
             pass
 
@@ -1987,6 +1990,4 @@ class SparkConnectClient(object):
             self.thread_local.ml_caches = set()
 
         self.disable_reattachable_execute()
-        # Todo add a pattern to delete all model in one command
-        for model_id in self.thread_local.ml_caches:
-            self._delete_ml_cache(model_id)
+        self._delete_ml_cache([model_id for model_id in self.thread_local.ml_caches])
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py
index 666cb1efdd2b..a2475a94c447 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -40,7 +40,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as spark_dot_connect_dot_ml_
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfb\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1 [...]
+    b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xfd\t\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x1 [...]
 )
 
 _globals = globals()
@@ -54,21 +54,21 @@ if not _descriptor._USE_C_DESCRIPTORS:
     _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None
     _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001"
     _globals["_MLCOMMAND"]._serialized_start = 137
-    _globals["_MLCOMMAND"]._serialized_end = 1412
+    _globals["_MLCOMMAND"]._serialized_end = 1414
     _globals["_MLCOMMAND_FIT"]._serialized_start = 480
     _globals["_MLCOMMAND_FIT"]._serialized_end = 658
     _globals["_MLCOMMAND_DELETE"]._serialized_start = 660
-    _globals["_MLCOMMAND_DELETE"]._serialized_end = 719
-    _globals["_MLCOMMAND_WRITE"]._serialized_start = 722
-    _globals["_MLCOMMAND_WRITE"]._serialized_end = 1132
-    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1034
-    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1092
-    _globals["_MLCOMMAND_READ"]._serialized_start = 1134
-    _globals["_MLCOMMAND_READ"]._serialized_end = 1215
-    _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1218
-    _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1401
-    _globals["_MLCOMMANDRESULT"]._serialized_start = 1415
-    _globals["_MLCOMMANDRESULT"]._serialized_end = 1818
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1608
-    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1803
+    _globals["_MLCOMMAND_DELETE"]._serialized_end = 721
+    _globals["_MLCOMMAND_WRITE"]._serialized_start = 724
+    _globals["_MLCOMMAND_WRITE"]._serialized_end = 1134
+    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1036
+    _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1094
+    _globals["_MLCOMMAND_READ"]._serialized_start = 1136
+    _globals["_MLCOMMAND_READ"]._serialized_end = 1217
+    _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1220
+    _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1403
+    _globals["_MLCOMMANDRESULT"]._serialized_start = 1417
+    _globals["_MLCOMMANDRESULT"]._serialized_end = 1820
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1610
+    _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 1805
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi
index 3a1e9155d71d..1d2e06d536e1 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi
@@ -111,25 +111,27 @@ class MlCommand(google.protobuf.message.Message):
         ) -> typing_extensions.Literal["params"] | None: ...
 
     class Delete(google.protobuf.message.Message):
-        """Command to delete the cached object which could be a model
+        """Command to delete the cached objects which could be a model
         or summary evaluated by a model
         """
 
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        OBJ_REF_FIELD_NUMBER: builtins.int
+        OBJ_REFS_FIELD_NUMBER: builtins.int
         @property
-        def obj_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ...
+        def obj_refs(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            pyspark.sql.connect.proto.ml_common_pb2.ObjectRef
+        ]: ...
         def __init__(
             self,
             *,
-            obj_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ...,
+            obj_refs: collections.abc.Iterable[pyspark.sql.connect.proto.ml_common_pb2.ObjectRef]
+            | None = ...,
         ) -> None: ...
-        def HasField(
-            self, field_name: typing_extensions.Literal["obj_ref", b"obj_ref"]
-        ) -> builtins.bool: ...
         def ClearField(
-            self, field_name: typing_extensions.Literal["obj_ref", b"obj_ref"]
+            self, field_name: typing_extensions.Literal["obj_refs", b"obj_refs"]
         ) -> None: ...
 
     class Write(google.protobuf.message.Message):
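
As a usage sketch of the regenerated stubs above (a hypothetical snippet: the module paths and field names are taken from this diff; the repeated-field behavior is standard protobuf, where `HasField` no longer applies to `obj_refs` and the field is instead inspected by length):

```python
from pyspark.sql.connect.proto import ml_pb2, ml_common_pb2

# Repeated composite fields accept an iterable of messages at construction.
delete = ml_pb2.MlCommand.Delete(
    obj_refs=[
        ml_common_pb2.ObjectRef(id="model_1"),
        ml_common_pb2.ObjectRef(id="model_2"),
    ]
)
assert len(delete.obj_refs) == 2  # repeated fields are checked by length
delete.ClearField("obj_refs")  # HasField("obj_refs") would raise ValueError
```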
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
index 6e469bb9027e..cdb077f5055a 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto
@@ -48,10 +48,10 @@ message MlCommand {
     Relation dataset = 3;
   }
 
-  // Command to delete the cached object which could be a model
+  // Command to delete the cached objects which could be a model
   // or summary evaluated by a model
   message Delete {
-    ObjectRef obj_ref = 1;
+    repeated ObjectRef obj_refs = 1;
   }
 
   // Command to write ML operator
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 9a9e156f91cd..9e3472ba8181 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -171,11 +171,12 @@ private[connect] object MLHandler extends Logging {
         }
 
       case proto.MlCommand.CommandCase.DELETE =>
-        val objId = mlCommand.getDelete.getObjRef.getId
         var result = false
-        if (!objId.contains(".")) {
-          mlCache.remove(objId)
-          result = true
+        mlCommand.getDelete.getObjRefsList.asScala.toArray.foreach { objId =>
+          if (!objId.getId.contains(".")) {
+            mlCache.remove(objId.getId)
+            result = true
+          }
         }
         proto.MlCommandResult
           .newBuilder()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
