This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 59b8a4489c87 [SPARK-54001][SQL] Optimize memory usage in session 
cloning with ref-counted cached local relations
59b8a4489c87 is described below

commit 59b8a4489c878fa3a9aa6b7fbae760f2fc80eb9d
Author: pranavdev022 <[email protected]>
AuthorDate: Thu Oct 23 15:31:06 2025 -0400

    [SPARK-54001][SQL] Optimize memory usage in session cloning with 
ref-counted cached local relations
    
    ### What changes were proposed in this pull request?
    This PR optimizes memory management for cached local relations when cloning 
Spark sessions by implementing reference counting instead of data replication.
    
    **Current behavior:**
    - When a session is cloned, cached local relation data stored in the block 
manager is replicated.
    - Each clone creates a duplicate copy of the data with a new block ID.
    - This causes unnecessary memory pressure.
    
    **Proposed changes:**
    - Implement reference counting for cached local relations during session 
cloning.
    - Retain the same block ID and data reference when cloning sessions, 
incrementing a ref count instead of copying
    - Add a hash-to-blockId mapping in ArtifactManager for efficient block 
lookup
    - Clean up blocks from block manager memory when ref count reaches zero
    
    ### Why are the changes needed?
    Cloning sessions is a common operation in Spark applications (e.g., for 
creating isolated execution contexts). The current approach of duplicating 
cached data can significantly increase memory footprint, especially when:
    - Sessions are cloned frequently
    - Cached relations contain large datasets
    - Multiple clones exist simultaneously
    
    This optimization reduces memory pressure, improves performance by avoiding 
unnecessary data copies.
    
    ### Does this PR introduce _any_ user-facing change?
    No. This is an internal optimization that improves memory efficiency 
without changing user-facing APIs or behavior.
    
    ### How was this patch tested?
    - Added unit tests to verify the reference count logic functioning.
    - Existing unit tests for ArtifactManager and session cloning.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #52651 from pranavdev022/clone-artifactmanager-fix.
    
    Authored-by: pranavdev022 <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 project/MimaExcludes.scala                         |   5 +-
 .../CheckConnectJvmClientCompatibility.scala       |   2 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |   6 +-
 .../spark/sql/artifact/ArtifactManager.scala       | 102 ++++++++++++++-------
 .../spark/sql/artifact/ArtifactManagerSuite.scala  |  60 ++++++++++--
 5 files changed, 129 insertions(+), 46 deletions(-)

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6fba665c8b24..1614ec212c2e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -45,7 +45,10 @@ object MimaExcludes {
     
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.collection.PrimitiveKeyOpenHashMap*"),
 
     // [SPARK-54041][SQL] Enable Direct Passthrough Partitioning in the 
DataFrame API
-    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById")
+    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById"),
+
+    // [SPARK-54001][CONNECT] Replace block copying with ref-counting in 
ArtifactManager cloning
+    
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.artifact.ArtifactManager.cachedBlockIdList")
   )
 
   // Default exclude rules
diff --git 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index cbaa4f5ea07f..92adc8eb9346 100644
--- 
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -234,6 +234,8 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.artifact.ArtifactManager$"),
       ProblemFilters.exclude[MissingClassProblem](
         
"org.apache.spark.sql.artifact.ArtifactManager$SparkContextResourceType$"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.artifact.RefCountedCacheId"),
 
       // ColumnNode conversions
       
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession"),
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a7a8f3506dea..ebcf462b84ce 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -60,7 +60,7 @@ import 
org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.classic.{Catalog, DataFrameWriter, Dataset, 
MergeIntoWriter, RelationalGroupedDataset, SparkSession, TypedAggUtils, 
UserDefinedFunctionUtils}
 import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket, LiteralValueProtoConverter, StorageLevelProtoConverter, 
StreamingListenerPacket, UdfPacket}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, 
StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
 import 
org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.ml.MLHandler
 import org.apache.spark.sql.connect.pipelines.PipelinesHandler
@@ -1330,7 +1330,9 @@ class SparkConnectPlanner(
 
   private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): 
LogicalPlan = {
     val blockManager = session.sparkContext.env.blockManager
-    val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash)
+    val blockId = 
session.artifactManager.getCachedBlockId(rel.getHash).getOrElse {
+      throw InvalidPlanInput(s"Cannot find a cached local relation for hash: 
${rel.getHash}")
+    }
     val bytes = blockManager.getLocalBytes(blockId)
     bytes
       .map { blockData =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index de91e5e8a44b..5889fe581d4e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -20,10 +20,9 @@ package org.apache.spark.sql.artifact
 import java.io.{File, IOException}
 import java.lang.ref.Cleaner
 import java.net.{URI, URL, URLClassLoader}
-import java.nio.ByteBuffer
 import java.nio.file.{CopyOption, Files, Path, Paths, StandardCopyOption}
-import java.util.concurrent.CopyOnWriteArrayList
-import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
 
 import scala.collection.mutable.ListBuffer
 import scala.jdk.CollectionConverters._
@@ -114,7 +113,7 @@ class ArtifactManager(session: SparkSession) extends 
AutoCloseable with Logging
     }
   }
 
-  protected val cachedBlockIdList = new CopyOnWriteArrayList[CacheId]
+  private val hashToCachedIdMap = new ConcurrentHashMap[String, 
RefCountedCacheId]
   protected val jarsList = new CopyOnWriteArrayList[Path]
   protected val pythonIncludeList = new CopyOnWriteArrayList[String]
   protected val sparkContextRelativePaths =
@@ -136,6 +135,10 @@ class ArtifactManager(session: SparkSession) extends 
AutoCloseable with Logging
    */
   def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
 
+  protected[sql] def getCachedBlockId(hash: String): Option[CacheId] = {
+    Option(hashToCachedIdMap.get(hash)).map(_.id)
+  }
+
   private def transferFile(
       source: Path,
       target: Path,
@@ -192,7 +195,14 @@ class ArtifactManager(session: SparkSession) extends 
AutoCloseable with Logging
           blockSize = tmpFile.length(),
           tellMaster = false)
         updater.save()
-        cachedBlockIdList.add(blockId)
+        val oldBlock = hashToCachedIdMap.put(blockId.hash, new 
RefCountedCacheId(blockId))
+        if (oldBlock != null) {
+          logWarning(
+            log"Replacing existing cache artifact with hash 
${MDC(LogKeys.BLOCK_ID, blockId)} " +
+              log"in session ${MDC(LogKeys.SESSION_ID, session.sessionUUID)}. 
" +
+              log"This may indicate duplicate artifact addition.")
+          oldBlock.release(blockManager)
+        }
       }(finallyBlock = { tmpFile.delete() })
     } else if 
(normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) {
       // Move class files to the right directory.
@@ -354,10 +364,27 @@ class ArtifactManager(session: SparkSession) extends 
AutoCloseable with Logging
     if (artifactPath.toFile.exists()) {
       Utils.copyDirectory(artifactPath.toFile, 
newArtifactManager.artifactPath.toFile)
     }
-    val blockManager = sparkContext.env.blockManager
-    val newBlockIds = cachedBlockIdList.asScala.map { blockId =>
-      val newBlockId = blockId.copy(sessionUUID = newSession.sessionUUID)
-      copyBlock(blockId, newBlockId, blockManager)
+
+    // Share cached blocks with the cloned session by copying the references 
and incrementing
+    // their reference counts. Both the original and cloned ArtifactManager 
will reference the
+    // same underlying cached data blocks. When either session releases a 
block, only the ref-count
+    // decreases.
+    // The block is removed from memory only when the ref-count reaches zero.
+    hashToCachedIdMap.forEach { (hash: String, refCountedCacheId: 
RefCountedCacheId) =>
+      try {
+        refCountedCacheId.acquire()  // Increment ref-count to prevent 
premature cleanup
+        newArtifactManager.hashToCachedIdMap.put(hash, refCountedCacheId)
+      } catch {
+        case e: SparkRuntimeException if e.getCondition == 
"BLOCK_ALREADY_RELEASED" =>
+          // The parent session was closed or this block was released during 
cloning.
+          // This indicates a race condition - we cannot safely complete the 
clone operation.
+          // With the ref-counting optimization, cloning is fast and this 
should be rare.
+          throw new SparkRuntimeException(
+            "INTERNAL_ERROR",
+            Map("message" -> (s"Cannot clone ArtifactManager: cached block 
with hash $hash " +
+              s"was already released. The parent session may have been closed 
during cloning.")),
+            e)
+      }
     }
 
     // Re-register resources to SparkContext
@@ -382,7 +409,6 @@ class ArtifactManager(session: SparkSession) extends 
AutoCloseable with Logging
       }
     }
 
-    newArtifactManager.cachedBlockIdList.addAll(newBlockIds.asJava)
     newArtifactManager.jarsList.addAll(jarsList)
     newArtifactManager.pythonIncludeList.addAll(pythonIncludeList)
     
newArtifactManager.sparkContextRelativePaths.addAll(sparkContextRelativePaths)
@@ -412,10 +438,16 @@ class ArtifactManager(session: SparkSession) extends 
AutoCloseable with Logging
     // Note that this will only be run once per instance.
     cleanable.clean()
 
+    // Clean-up cached blocks.
+    val blockManager = session.sparkContext.env.blockManager
+    hashToCachedIdMap.values().forEach { refCountedCacheId =>
+      refCountedCacheId.release(blockManager)
+    }
+    hashToCachedIdMap.clear()
+
     // Clean up internal trackers
     jarsList.clear()
     pythonIncludeList.clear()
-    cachedBlockIdList.clear()
     sparkContextRelativePaths.clear()
 
     // Removed cached classloader
@@ -484,25 +516,6 @@ object ArtifactManager extends Logging {
     val JAR, FILE, ARCHIVE = Value
   }
 
-  private def copyBlock(fromId: CacheId, toId: CacheId, blockManager: 
BlockManager): CacheId = {
-    require(fromId != toId)
-    blockManager.getLocalBytes(fromId) match {
-      case Some(blockData) =>
-        Utils.tryWithSafeFinallyAndFailureCallbacks {
-          val updater = blockManager.ByteBufferBlockStoreUpdater(
-            blockId = toId,
-            level = StorageLevel.MEMORY_AND_DISK_SER,
-            classTag = implicitly[ClassTag[Array[Byte]]],
-            bytes = blockData.toChunkedByteBuffer(ByteBuffer.allocate),
-            tellMaster = false)
-          updater.save()
-          toId
-        }(finallyBlock = { blockManager.releaseLock(fromId); 
blockData.dispose() })
-      case None =>
-        throw SparkException.internalError(s"Block $fromId not found in the 
block manager.")
-    }
-  }
-
   // Shared cleaner instance
   private val cleaner: Cleaner = Cleaner.create()
 
@@ -530,10 +543,6 @@ object ArtifactManager extends Logging {
       }
     }
 
-    // Clean up cached relations
-    val blockManager = sparkContext.env.blockManager
-    blockManager.removeCache(sparkSessionUUID)
-
     // Clean up artifacts folder
     try {
       Utils.deleteRecursively(artifactPath.toFile)
@@ -550,3 +559,28 @@ private[artifact] case class ArtifactStateForCleanup(
   sparkContext: SparkContext,
   jobArtifactState: JobArtifactState,
   artifactPath: Path)
+
+private class RefCountedCacheId(val id: CacheId) {
+  private val rc = new AtomicInteger(1)
+
+  def acquire(): Unit = updateRc(1)
+
+  def release(blockManager: BlockManager): Unit = {
+    val newRc = updateRc(-1)
+    if (newRc == 0) {
+      blockManager.removeBlock(id)
+    }
+  }
+
+  private def updateRc(delta: Int): Int = {
+    rc.updateAndGet { currentRc: Int =>
+      if (currentRc == 0) {
+        throw new SparkRuntimeException(
+          "BLOCK_ALREADY_RELEASED",
+          Map("blockId" -> id.toString)
+        )
+      }
+      currentRc + delta
+    }
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index f4a4ab012c2e..ead2d52edff3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -503,15 +503,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
       assert(newArtifactManager.artifactPath !== artifactManager.artifactPath)
 
       // Load the cached artifact
-      val blockManager = newSession.sparkContext.env.blockManager
-      for (sessionId <- Seq(spark.sessionUUID, newSession.sessionUUID)) {
-        val cacheId = CacheId(sessionId, "test")
-        try {
-          
assert(blockManager.getLocalBytes(cacheId).get.toByteBuffer().array() === 
testBytes)
-        } finally {
-          blockManager.releaseLock(cacheId)
-        }
-      }
+      assert(spark.artifactManager.getCachedBlockId("test")
+        == newArtifactManager.getCachedBlockId("test"))
 
       val allFiles = Utils.listFiles(newArtifactManager.artifactPath.toFile)
       assert(allFiles.size() === 3)
@@ -540,6 +533,55 @@ class ArtifactManagerSuite extends SharedSparkSession {
     }
   }
 
+  test("Share blocks between ArtifactManagers") {
+    def isBlockRegistered(id: CacheId): Boolean = {
+      sparkContext.env.blockManager.getStatus(id).isDefined
+    }
+
+    def addCachedArtifact(session: SparkSession, name: String, data: String): 
CacheId = {
+      val bytes = new Artifact.InMemory(data.getBytes(StandardCharsets.UTF_8))
+      
session.artifactManager.addLocalArtifacts(Artifact.newCacheArtifact(name, 
bytes) :: Nil)
+      val id = CacheId(session.sessionUUID, name)
+      assert(isBlockRegistered(id))
+      id
+    }
+
+    // Create fresh session so there is no interference with other tests.
+    val session1 = spark.newSession()
+    val b1 = addCachedArtifact(session1, "b1", "b_one")
+    val b2 = addCachedArtifact(session1, "b2", "b_two")
+
+    // Clone, check that existing blocks are the same, add another block, 
clean-up, make sure
+    // shared blocks survive and new block is cleaned.
+    val session2 = session1.cloneSession()
+    val b3 = addCachedArtifact(session2, "b3", "b_three")
+    session2.artifactManager.cleanUpResourcesForTesting()
+    assert(isBlockRegistered(b1))
+    assert(isBlockRegistered(b2))
+    assert(!isBlockRegistered(b3))
+
+    // Clone, check that existing blocks are the same, replace existing 
blocks, clone parent, check
+    // that inherited blocks are removed now.
+    val session3 = session1.cloneSession()
+    session1.artifactManager.cleanUpResourcesForTesting()
+    assert(isBlockRegistered(b1))
+    assert(isBlockRegistered(b2))
+    assert(session3.artifactManager.getCachedBlockId("b1").get == b1)
+    assert(session3.artifactManager.getCachedBlockId("b2").get == b2)
+
+    val b1a = addCachedArtifact(session3, "b1", "b_one_a")
+    val b2a = addCachedArtifact(session3, "b2", "b_two_a")
+    assert(!isBlockRegistered(b1))
+    assert(!isBlockRegistered(b2))
+    assert(session3.artifactManager.getCachedBlockId("b1").get == b1a)
+    assert(session3.artifactManager.getCachedBlockId("b2").get == b2a)
+
+    // Clean-up last AM. No block should be left.
+    session3.artifactManager.cleanUpResourcesForTesting()
+    assert(!isBlockRegistered(b1a))
+    assert(!isBlockRegistered(b2a))
+  }
+
   test("Codegen cache should be invalid when artifacts are added - class 
artifact") {
     withTempDir { dir =>
       runCodegenTest("class artifact") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to