This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fb8f5f69bc63 [SPARK-54708][CONNECT][ML] Optimize ML cache cleanup with lazy directory creation
fb8f5f69bc63 is described below
commit fb8f5f69bc635f9cfd1b8103bde929638ae3628c
Author: Xi Lyu <[email protected]>
AuthorDate: Tue Dec 16 09:53:08 2025 +0800
[SPARK-54708][CONNECT][ML] Optimize ML cache cleanup with lazy directory creation
### What changes were proposed in this pull request?
In the current implementation, during session holder cleanup of non-ML
sessions, mlCache.clear() ([code
link](https://github.com/apache/spark/blob/43f7936d7b3a4701e3d0fdb44663006cbe0db70b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala#L381))
still has to delete the eagerly created offloadedModelsDir, which adds ~10 ms
of unnecessary latency.
In this PR, we make the directory creation lazy, so no empty directory is
created (or later deleted) when a session performs no SparkML operations.
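For illustration, here is a minimal, self-contained sketch of the same pattern
(a hypothetical `LazyDirHolder` class, not the actual MLCache code):

```scala
import java.nio.file.{Files, Path, Paths}
import java.util.concurrent.atomic.AtomicBoolean

// Minimal sketch of the lazy-creation pattern; LazyDirHolder is a
// hypothetical stand-in, not the actual MLCache implementation.
class LazyDirHolder(sessionId: String) {
  // Records whether the directory was ever materialized on disk.
  val hasCreatedDir = new AtomicBoolean(false)

  // `lazy val` defers the filesystem call until first access, so a
  // session that never touches ML state creates nothing to clean up.
  lazy val dir: Path = {
    val path = Paths.get(
      System.getProperty("java.io.tmpdir"), "example_cache", sessionId)
    val created = Files.createDirectories(path)
    hasCreatedDir.set(true)
    created
  }

  // Cleanup is a no-op unless the directory was actually created;
  // checking the flag avoids forcing the lazy initialization.
  def close(): Unit = {
    if (hasCreatedDir.get()) {
      Files.deleteIfExists(dir)
    }
  }
}
```

Reading `dir` inside `close()` is safe here because the guard guarantees the
lazy val was already initialized; a session that never forced it skips the
filesystem work entirely.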
### Why are the changes needed?
Improves the performance of the ReleaseSession RPC by ~10 ms.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New test and existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53475 from xi-db/lazy-ml-dir-creation.
Authored-by: Xi Lyu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../org/apache/spark/sql/connect/ml/MLCache.scala | 29 +++++++++++++++++++---
.../spark/sql/connect/service/SessionHolder.scala | 3 ++-
.../org/apache/spark/sql/connect/ml/MLSuite.scala | 29 ++++++++++++++++++++++
3 files changed, 56 insertions(+), 5 deletions(-)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 7761c0078b27..e130e637e820 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -20,9 +20,10 @@ import java.io.File
import java.nio.file.{Files, Path, Paths}
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
-import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import scala.collection.mutable
+import scala.util.control.NonFatal
import com.google.common.cache.{CacheBuilder, RemovalNotification}
@@ -44,12 +45,17 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
private[ml] val totalMLCacheInMemorySizeBytes: AtomicLong = new AtomicLong(0)
- val offloadedModelsDir: Path = {
- val path = Paths.get(
+ // Track if ML directories were ever created in this session
+ private[ml] val hasCreatedMLDirs: AtomicBoolean = new AtomicBoolean(false)
+
+ lazy val offloadedModelsDir: Path = {
+ val dirPath = Paths.get(
System.getProperty("java.io.tmpdir"),
"spark_connect_model_cache",
sessionHolder.sessionId)
- Files.createDirectories(path)
+ val createdPath = Files.createDirectories(dirPath)
+ hasCreatedMLDirs.set(true)
+ createdPath
}
private[spark] def getMemoryControlEnabled: Boolean = {
sessionHolder.session.conf.get(
@@ -173,6 +179,21 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
}
}
+ /**
+ * Closes the MLCache and cleans up resources. Only performs cleanup if ML directories or models
+ * were created during the session. Called by SessionHolder during session cleanup.
+ */
+ def close(): Unit = {
+ if (hasCreatedMLDirs.get() || cachedModel.size() > 0) {
+ try {
+ clear()
+ } catch {
+ case NonFatal(e) =>
+ logWarning(log"Failed to cleanup ML cache resources", e)
+ }
+ }
+ }
+
/**
* Get the object by the key
* @param refId
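(A side note on the pattern above: a Scala `lazy val` initializer runs at most
once even under concurrent access, since the compiler wraps it in
double-checked locking, so the directory is created exactly once per session.
The separate `hasCreatedMLDirs` flag exists so that cleanup can check whether
creation ever happened without itself forcing the lazy initialization.)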
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index f3128ce50840..d0d0f0ba750a 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -378,7 +378,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
// remove all executions and no new executions will be added in the meanwhile.
SparkConnectService.executionManager.removeAllExecutionsForSession(this.key)
- mlCache.clear()
+ // Clean up ML cache (only if ML models were created)
+ mlCache.close()
session.cleanupPythonWorkerLogs()
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index bdc094e7b2b7..6c9b1819021f 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -443,6 +443,35 @@ class MLSuite extends MLHelper {
}
}
+ test("MLCache close()") {
+ // Test 1: Non-ML session - directories should NOT be created
+ val nonMLSessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val mlCache1 = nonMLSessionHolder.mlCache
+
+ assert(!mlCache1.hasCreatedMLDirs.get())
+ mlCache1.close()
+ assert(!mlCache1.hasCreatedMLDirs.get())
+
+ // Test 2: ML session - directories should be created, close should run cleanup
+ val mlSessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val mlCache2 = mlSessionHolder.mlCache
+
+ val modelId = trainLogisticRegressionModel(mlSessionHolder)
+ assert(mlCache2.hasCreatedMLDirs.get())
+ assert(mlCache2.get(modelId) != null)
+ mlCache2.close()
+ assert(mlCache2.cachedModel.isEmpty)
+
+ // Test 3: Edge case - register then remove model, close should still run cleanup
+ val edgeCaseSessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+ val mlCache3 = edgeCaseSessionHolder.mlCache
+ val modelId2 = trainLogisticRegressionModel(edgeCaseSessionHolder)
+ mlCache3.remove(modelId2)
+ assert(mlCache3.hasCreatedMLDirs.get())
+ mlCache3.close()
+ assert(mlCache3.cachedModel.isEmpty)
+ }
+
def trainTreeModel(
sessionHolder: SessionHolder,
estimator: proto.MlOperator.Builder): String = {
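Using the hypothetical `LazyDirHolder` from the sketch in the description
above, the behavior the new test verifies looks roughly like this:

```scala
val holder = new LazyDirHolder("session-1")
assert(!holder.hasCreatedDir.get()) // nothing on disk yet
holder.close()                      // no-op: no filesystem work performed

val created = holder.dir            // first access creates the directory
assert(holder.hasCreatedDir.get())
holder.close()                      // now actually deletes the empty directory
```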