This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 59b8a4489c87 [SPARK-54001][SQL] Optimize memory usage in session cloning with ref-counted cached local relations
59b8a4489c87 is described below
commit 59b8a4489c878fa3a9aa6b7fbae760f2fc80eb9d
Author: pranavdev022 <[email protected]>
AuthorDate: Thu Oct 23 15:31:06 2025 -0400
[SPARK-54001][SQL] Optimize memory usage in session cloning with ref-counted cached local relations
### What changes were proposed in this pull request?
This PR optimizes memory management for cached local relations when cloning
Spark sessions by implementing reference counting instead of data replication.
**Current behavior:**
- When a session is cloned, cached local relation data stored in the block
manager is replicated.
- Each clone creates a duplicate copy of the data with a new block ID.
- This causes unnecessary memory pressure.
**Proposed changes:**
- Implement reference counting for cached local relations during session cloning (a simplified sketch of the idea follows this list).
- Retain the same block ID and data reference when cloning sessions, incrementing a ref count instead of copying the data.
- Add a hash-to-blockId mapping in ArtifactManager for efficient block lookup.
- Clean up blocks from block manager memory when the ref count reaches zero.
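The snippet below is a minimal, self-contained sketch of the ref-counting lifecycle only; the names (`RefCountedEntry`, `onLastRelease`, `RefCountingDemo`) are illustrative and are not part of the actual `ArtifactManager` change, which tracks `CacheId` blocks in the block manager.

```scala
import java.util.concurrent.atomic.AtomicInteger

// A cached payload shared between a parent session and its clones. The
// payload is only dropped once every holder has released its reference.
final class RefCountedEntry[T](val payload: T, onLastRelease: T => Unit) {
  private val refs = new AtomicInteger(1) // the creating session holds the first reference

  def acquire(): Unit = {
    refs.updateAndGet { n =>
      if (n == 0) throw new IllegalStateException("entry already released")
      n + 1
    }
  }

  def release(): Unit = {
    val remaining = refs.updateAndGet { n =>
      if (n == 0) throw new IllegalStateException("entry already released")
      n - 1
    }
    // Only the last release frees the underlying data
    // (in Spark this corresponds to removing the block from the block manager).
    if (remaining == 0) onLastRelease(payload)
  }
}

object RefCountingDemo extends App {
  val entry = new RefCountedEntry[String]("block-for-hash-abc", id => println(s"removing $id"))
  entry.acquire() // cloning a session shares the entry instead of copying the block
  entry.release() // parent session closes: the data survives, only the count drops
  entry.release() // last clone closes: the block is removed
}
```

In this PR, `cloneSession()` plays the role of `acquire()`, while closing a session (or replacing a cached artifact with the same hash) plays the role of `release()`.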
### Why are the changes needed?
Cloning sessions is a common operation in Spark applications (e.g., for
creating isolated execution contexts). The current approach of duplicating
cached data can significantly increase memory footprint, especially when:
- Sessions are cloned frequently
- Cached relations contain large datasets
- Multiple clones exist simultaneously
This optimization reduces memory pressure and improves performance by avoiding unnecessary data copies.
### Does this PR introduce _any_ user-facing change?
No. This is an internal optimization that improves memory efficiency
without changing user-facing APIs or behavior.
### How was this patch tested?
- Added unit tests to verify the reference counting logic.
- Ran the existing unit tests for ArtifactManager and session cloning.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52651 from pranavdev022/clone-artifactmanager-fix.
Authored-by: pranavdev022 <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
project/MimaExcludes.scala | 5 +-
.../CheckConnectJvmClientCompatibility.scala | 2 +
.../sql/connect/planner/SparkConnectPlanner.scala | 6 +-
.../spark/sql/artifact/ArtifactManager.scala | 102 ++++++++++++++-------
.../spark/sql/artifact/ArtifactManagerSuite.scala | 60 ++++++++++--
5 files changed, 129 insertions(+), 46 deletions(-)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6fba665c8b24..1614ec212c2e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -45,7 +45,10 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.PrimitiveKeyOpenHashMap*"),
// [SPARK-54041][SQL] Enable Direct Passthrough Partitioning in the DataFrame API
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById")
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById"),
+
+ // [SPARK-54001][CONNECT] Replace block copying with ref-counting in ArtifactManager cloning
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.artifact.ArtifactManager.cachedBlockIdList")
)
// Default exclude rules
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index cbaa4f5ea07f..92adc8eb9346 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -234,6 +234,8 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.artifact.ArtifactManager$"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.artifact.RefCountedCacheId"),
// ColumnNode conversions
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession"),
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a7a8f3506dea..ebcf462b84ce 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -60,7 +60,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.classic.{Catalog, DataFrameWriter, Dataset, MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, UserDefinedFunctionUtils}
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.ml.MLHandler
import org.apache.spark.sql.connect.pipelines.PipelinesHandler
@@ -1330,7 +1330,9 @@ class SparkConnectPlanner(
private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = {
val blockManager = session.sparkContext.env.blockManager
- val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash)
+ val blockId = session.artifactManager.getCachedBlockId(rel.getHash).getOrElse {
+ throw InvalidPlanInput(s"Cannot find a cached local relation for hash: ${rel.getHash}")
+ }
val bytes = blockManager.getLocalBytes(blockId)
bytes
.map { blockData =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index de91e5e8a44b..5889fe581d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql.artifact
import java.io.{File, IOException}
import java.lang.ref.Cleaner
import java.net.{URI, URL, URLClassLoader}
-import java.nio.ByteBuffer
import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
-import java.util.concurrent.CopyOnWriteArrayList
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
@@ -114,7 +113,7 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
}
}
- protected val cachedBlockIdList = new CopyOnWriteArrayList[CacheId]
+ private val hashToCachedIdMap = new ConcurrentHashMap[String, RefCountedCacheId]
protected val jarsList = new CopyOnWriteArrayList[Path]
protected val pythonIncludeList = new CopyOnWriteArrayList[String]
protected val sparkContextRelativePaths =
@@ -136,6 +135,10 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
*/
def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
+ protected[sql] def getCachedBlockId(hash: String): Option[CacheId] = {
+ Option(hashToCachedIdMap.get(hash)).map(_.id)
+ }
+
private def transferFile(
source: Path,
target: Path,
@@ -192,7 +195,14 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
blockSize = tmpFile.length(),
tellMaster = false)
updater.save()
- cachedBlockIdList.add(blockId)
+ val oldBlock = hashToCachedIdMap.put(blockId.hash, new RefCountedCacheId(blockId))
+ if (oldBlock != null) {
+ logWarning(
+ log"Replacing existing cache artifact with hash ${MDC(LogKeys.BLOCK_ID, blockId)} " +
+ log"in session ${MDC(LogKeys.SESSION_ID, session.sessionUUID)}. " +
+ log"This may indicate duplicate artifact addition.")
+ oldBlock.release(blockManager)
+ }
}(finallyBlock = { tmpFile.delete() })
} else if (normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) {
// Move class files to the right directory.
@@ -354,10 +364,27 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
if (artifactPath.toFile.exists()) {
Utils.copyDirectory(artifactPath.toFile, newArtifactManager.artifactPath.toFile)
}
- val blockManager = sparkContext.env.blockManager
- val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
- val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
- copyBlock(blockId, newBlockId, blockManager)
+
+ // Share cached blocks with the cloned session by copying the references and incrementing
+ // their reference counts. Both the original and cloned ArtifactManager will reference the
+ // same underlying cached data blocks. When either session releases a block, only the ref-count
+ // decreases.
+ // The block is removed from memory only when the ref-count reaches zero.
+ hashToCachedIdMap.forEach { (hash: String, refCountedCacheId: RefCountedCacheId) =>
+ try {
+ refCountedCacheId.acquire() // Increment ref-count to prevent premature cleanup
+ newArtifactManager.hashToCachedIdMap.put(hash, refCountedCacheId)
+ } catch {
+ case e: SparkRuntimeException if e.getCondition == "BLOCK_ALREADY_RELEASED" =>
+ // The parent session was closed or this block was released during cloning.
+ // This indicates a race condition - we cannot safely complete the clone operation.
+ // With the ref-counting optimization, cloning is fast and this should be rare.
+ throw new SparkRuntimeException(
+ "INTERNAL_ERROR",
+ Map("message" -> (s"Cannot clone ArtifactManager: cached block with hash $hash " +
+ s"was already released. The parent session may have been closed during cloning.")),
+ e)
+ }
}
// Re-register resources to SparkContext
@@ -382,7 +409,6 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
}
}
- newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
newArtifactManager.jarsList.addAll(jarsList)
newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
newArtifactManager.sparkContextRelativePaths.addAll(sparkContextRelativePaths)
@@ -412,10 +438,16 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
// Note that this will only be run once per instance.
cleanable.clean()
+ // Clean-up cached blocks.
+ val blockManager = session.sparkContext.env.blockManager
+ hashToCachedIdMap.values().forEach { refCountedCacheId =>
+ refCountedCacheId.release(blockManager)
+ }
+ hashToCachedIdMap.clear()
+
// Clean up internal trackers
jarsList.clear()
pythonIncludeList.clear()
- cachedBlockIdList.clear()
sparkContextRelativePaths.clear()
// Removed cached classloader
@@ -484,25 +516,6 @@ object ArtifactManager extends Logging {
val JAR, FILE, ARCHIVE = Value
}
- private def copyBlock(fromId: CacheId, toId: CacheId, blockManager: BlockManager): CacheId = {
- require(fromId != toId)
- blockManager.getLocalBytes(fromId) match {
- case Some(blockData) =>
- Utils.tryWithSafeFinallyAndFailureCallbacks {
- val updater = blockManager.ByteBufferBlockStoreUpdater(
- blockId = toId,
- level = StorageLevel.MEMORY_AND_DISK_SER,
- classTag = implicitly[ClassTag[Array[Byte]]],
- bytes = blockData.toChunkedByteBuffer(ByteBuffer.allocate),
- tellMaster = false)
- updater.save()
- toId
- }(finallyBlock = { blockManager.releaseLock(fromId); blockData.dispose() })
- case None =>
- throw SparkException.internalError(s"Block $fromId not found in the block manager.")
- }
- }
-
// Shared cleaner instance
private val cleaner: Cleaner = Cleaner.create()
@@ -530,10 +543,6 @@ object ArtifactManager extends Logging {
}
}
- // Clean up cached relations
- val blockManager = sparkContext.env.blockManager
- blockManager.removeCache(sparkSessionUUID)
-
// Clean up artifacts folder
try {
Utils.deleteRecursively(artifactPath.toFile)
@@ -550,3 +559,28 @@ private[artifact] case class ArtifactStateForCleanup(
sparkContext: SparkContext,
jobArtifactState: JobArtifactState,
artifactPath: Path)
+
+private class RefCountedCacheId(val id: CacheId) {
+ private val rc = new AtomicInteger(1)
+
+ def acquire(): Unit = updateRc(1)
+
+ def release(blockManager: BlockManager): Unit = {
+ val newRc = updateRc(-1)
+ if (newRc == 0) {
+ blockManager.removeBlock(id)
+ }
+ }
+
+ private def updateRc(delta: Int): Int = {
+ rc.updateAndGet { currentRc: Int =>
+ if (currentRc == 0) {
+ throw new SparkRuntimeException(
+ "BLOCK_ALREADY_RELEASED",
+ Map("blockId" -> id.toString)
+ )
+ }
+ currentRc + delta
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index f4a4ab012c2e..ead2d52edff3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -503,15 +503,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
assert(newArtifactManager.artifactPath !== artifactManager.artifactPath)
// Load the cached artifact
- val blockManager = newSession.sparkContext.env.blockManager
- for (sessionId <- Seq(spark.sessionUUID, newSession.sessionUUID)) {
- val cacheId = CacheId(sessionId, "test")
- try {
- assert(blockManager.getLocalBytes(cacheId).get.toByteBuffer().array() === testBytes)
- } finally {
- blockManager.releaseLock(cacheId)
- }
- }
+ assert(spark.artifactManager.getCachedBlockId("test")
+ == newArtifactManager.getCachedBlockId("test"))
val allFiles = Utils.listFiles(newArtifactManager.artifactPath.toFile)
assert(allFiles.size() === 3)
@@ -540,6 +533,55 @@ class ArtifactManagerSuite extends SharedSparkSession {
}
}
+ test("Share blocks between ArtifactManagers") {
+ def isBlockRegistered(id: CacheId): Boolean = {
+ sparkContext.env.blockManager.getStatus(id).isDefined
+ }
+
+ def addCachedArtifact(session: SparkSession, name: String, data: String): CacheId = {
+ val bytes = new Artifact.InMemory(data.getBytes(StandardCharsets.UTF_8))
+ session.artifactManager.addLocalArtifacts(Artifact.newCacheArtifact(name, bytes) :: Nil)
+ val id = CacheId(session.sessionUUID, name)
+ assert(isBlockRegistered(id))
+ id
+ }
+
+ // Create fresh session so there is no interference with other tests.
+ val session1 = spark.newSession()
+ val b1 = addCachedArtifact(session1, "b1", "b_one")
+ val b2 = addCachedArtifact(session1, "b2", "b_two")
+
+ // Clone, check that existing blocks are the same, add another block, clean-up, make sure
+ // shared blocks survive and new block is cleaned.
+ val session2 = session1.cloneSession()
+ val b3 = addCachedArtifact(session2, "b3", "b_three")
+ session2.artifactManager.cleanUpResourcesForTesting()
+ assert(isBlockRegistered(b1))
+ assert(isBlockRegistered(b2))
+ assert(!isBlockRegistered(b3))
+
+ // Clone, check that existing blocks are the same, replace existing blocks, clone parent, check
+ // that inherited blocks are removed now.
+ val session3 = session1.cloneSession()
+ session1.artifactManager.cleanUpResourcesForTesting()
+ assert(isBlockRegistered(b1))
+ assert(isBlockRegistered(b2))
+ assert(session3.artifactManager.getCachedBlockId("b1").get == b1)
+ assert(session3.artifactManager.getCachedBlockId("b2").get == b2)
+
+ val b1a = addCachedArtifact(session3, "b1", "b_one_a")
+ val b2a = addCachedArtifact(session3, "b2", "b_two_a")
+ assert(!isBlockRegistered(b1))
+ assert(!isBlockRegistered(b2))
+ assert(session3.artifactManager.getCachedBlockId("b1").get == b1a)
+ assert(session3.artifactManager.getCachedBlockId("b2").get == b2a)
+
+ // Clean-up last AM. No block should be left.
+ session3.artifactManager.cleanUpResourcesForTesting()
+ assert(!isBlockRegistered(b1a))
+ assert(!isBlockRegistered(b2a))
+ }
+
test("Codegen cache should be invalid when artifacts are added - class
artifact") {
withTempDir { dir =>
runCodegenTest("class artifact") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]