This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a75e1aa1cac2 [SPARK-53001] Integrate RocksDB Memory Usage with the 
Unified Memory Manager
a75e1aa1cac2 is described below

commit a75e1aa1cac2d4f135542788ff8b5246408641cf
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Thu Jul 31 17:24:52 2025 -0700

    [SPARK-53001] Integrate RocksDB Memory Usage with the Unified Memory Manager
    
    ### What changes were proposed in this pull request?
    
    Currently, RocksDB memory is untracked and not included in memory decisions 
in Spark (particularly when Photon is enabled). We want to factor the RocksDB 
memory usage into memory allocations so we don't hit OOMs. This change 
introduces a background memory polling thread, started by the 
UnifiedMemoryManager, that queries the memory usage of registered unmanaged 
memory consumers (such as RocksDB) at a fixed interval, configurable via 
spark.memory.unmanagedMemoryPollingInterval (default 1s; 0 disables polling).
    
    ### Why are the changes needed?
    
    This helps us avoid OOMs when RocksDB is used as the StateStoreProvider by 
taking RocksDB's memory usage into account when Spark allocates execution and 
storage memory.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51708 from ericm-db/rocksdb-mm.
    
    Authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
---
 .../org/apache/spark/internal/config/package.scala |  10 +
 .../apache/spark/memory/UnifiedMemoryManager.scala | 305 ++++++++++++++++++++-
 .../spark/memory/UnifiedMemoryManagerSuite.scala   | 284 +++++++++++++++++++
 .../sql/execution/streaming/state/RocksDB.scala    | 107 +++++++-
 .../streaming/state/RocksDBMemoryManager.scala     |  74 ++++-
 .../state/RocksDBStateStoreProvider.scala          |   9 +-
 .../FailureInjectionCheckpointFileManager.scala    |  12 +-
 .../RocksDBCheckpointFailureInjectionSuite.scala   |   3 +-
 .../state/RocksDBStateStoreIntegrationSuite.scala  |  78 +++++-
 9 files changed, 865 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala 
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 2c7c2f120b93..0be2a53d7a0f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -500,6 +500,16 @@ package object config {
     .doubleConf
     .createWithDefault(0.6)
 
+  private[spark] val UNMANAGED_MEMORY_POLLING_INTERVAL =
+    ConfigBuilder("spark.memory.unmanagedMemoryPollingInterval")
+      .doc("Interval for polling unmanaged memory users to track their memory 
usage. " +
+        "Unmanaged memory users are components that manage their own memory 
outside of " +
+        "Spark's core memory management, such as RocksDB for Streaming State 
Store. " +
+        "Setting this to 0 disables unmanaged memory polling.")
+      .version("4.1.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("1s")
+
   private[spark] val STORAGE_UNROLL_MEMORY_THRESHOLD =
     ConfigBuilder("spark.storage.unrollMemoryThreshold")
       .doc("Initial memory to request before unrolling any block")
diff --git 
a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala 
b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index d4ec6ed8495a..0aec2c232aab 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -17,11 +17,19 @@
 
 package org.apache.spark.memory
 
+import java.util.concurrent.{ConcurrentHashMap, ScheduledExecutorService, 
TimeUnit}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
+
+import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
+
 import org.apache.spark.{SparkConf, SparkIllegalArgumentException}
-import org.apache.spark.internal.{config, MDC}
+import org.apache.spark.internal.{config, Logging, LogKeys, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.internal.config.Tests._
+import org.apache.spark.internal.config.UNMANAGED_MEMORY_POLLING_INTERVAL
 import org.apache.spark.storage.BlockId
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * A [[MemoryManager]] that enforces a soft boundary between execution and 
storage such that
@@ -56,7 +64,47 @@ private[spark] class UnifiedMemoryManager(
     conf,
     numCores,
     onHeapStorageRegionSize,
-    maxHeapMemory - onHeapStorageRegionSize) {
+    maxHeapMemory - onHeapStorageRegionSize) with Logging  {
+
+  /**
+   * Unmanaged memory tracking infrastructure.
+   *
+   * Unmanaged memory refers to memory consumed by components that manage 
their own memory
+   * outside of Spark's unified memory management system. Examples include:
+   * - RocksDB state stores used in structured streaming
+   * - Native libraries with their own memory management
+   * - Off-heap caches managed outside of Spark
+   *
+   * We track this memory to:
+   * 1. Provide visibility into total memory usage on executors
+   * 2. Prevent OOM errors by accounting for it in memory allocation decisions
+   * 3. Enable better debugging and monitoring of memory-intensive applications
+   *
+   * The polling mechanism periodically queries registered unmanaged memory 
consumers
+   * to refresh the cached usage totals and to detect and remove inactive consumers.
+   */
+  // Configuration for polling interval (in milliseconds)
+  private val unmanagedMemoryPollingIntervalMs = 
conf.get(UNMANAGED_MEMORY_POLLING_INTERVAL)
+  // Initialize background polling if enabled
+  if (unmanagedMemoryPollingIntervalMs > 0) {
+    UnifiedMemoryManager.startPollingIfNeeded(unmanagedMemoryPollingIntervalMs)
+  }
+
+  /**
+   * Get the current unmanaged memory usage in bytes for a specific memory 
mode.
+   * @param memoryMode The memory mode (ON_HEAP or OFF_HEAP) to get usage for
+   * @return The current unmanaged memory usage in bytes
+   */
+  private def getUnmanagedMemoryUsed(memoryMode: MemoryMode): Long = {
+    // Only consider unmanaged memory if polling is enabled
+    if (unmanagedMemoryPollingIntervalMs <= 0) {
+      return 0L
+    }
+    memoryMode match {
+      case MemoryMode.ON_HEAP => UnifiedMemoryManager.unmanagedOnHeapUsed.get()
+      case MemoryMode.OFF_HEAP => 
UnifiedMemoryManager.unmanagedOffHeapUsed.get()
+    }
+  }
 
   private def assertInvariants(): Unit = {
     assert(onHeapExecutionMemoryPool.poolSize + 
onHeapStorageMemoryPool.poolSize == maxHeapMemory)
@@ -140,9 +188,15 @@ private[spark] class UnifiedMemoryManager(
      * in execution memory allocation across tasks, Otherwise, a task may 
occupy more than
      * its fair share of execution memory, mistakenly thinking that other 
tasks can acquire
      * the portion of storage memory that cannot be evicted.
+     *
+     * This also factors in unmanaged memory usage to ensure we don't 
over-allocate memory
+     * when unmanaged components are consuming significant memory.
      */
     def computeMaxExecutionPoolSize(): Long = {
-      maxMemory - math.min(storagePool.memoryUsed, storageRegionSize)
+      val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode)
+      val availableMemory = maxMemory - math.min(storagePool.memoryUsed, 
storageRegionSize)
+      // Reduce available memory by unmanaged memory usage to prevent 
over-allocation
+      math.max(0L, availableMemory - unmanagedMemory)
     }
 
     executionPool.acquireMemory(
@@ -165,11 +219,21 @@ private[spark] class UnifiedMemoryManager(
         offHeapStorageMemoryPool,
         maxOffHeapStorageMemory)
     }
-    if (numBytes > maxMemory) {
+
+    // Factor in unmanaged memory usage for the specific memory mode
+    val unmanagedMemory = getUnmanagedMemoryUsed(memoryMode)
+    val effectiveMaxMemory = math.max(0L, maxMemory - unmanagedMemory)
+
+    if (numBytes > effectiveMaxMemory) {
       // Fail fast if the block simply won't fit
       logInfo(log"Will not store ${MDC(BLOCK_ID, blockId)} as the required 
space" +
         log" (${MDC(NUM_BYTES, numBytes)} bytes) exceeds our" +
-        log" memory limit (${MDC(NUM_BYTES_MAX, maxMemory)} bytes)")
+        log" memory limit (${MDC(NUM_BYTES_MAX, effectiveMaxMemory)} bytes)" +
+        (if (unmanagedMemory > 0) {
+          log" (unmanaged memory usage: ${MDC(NUM_BYTES, unmanagedMemory)} 
bytes)"
+        } else {
+          log""
+        }))
       return false
     }
     if (numBytes > storagePool.memoryFree) {
@@ -191,7 +255,7 @@ private[spark] class UnifiedMemoryManager(
   }
 }
 
-object UnifiedMemoryManager {
+object UnifiedMemoryManager extends Logging {
 
   // Set aside a fixed amount of memory for non-storage, non-execution 
purposes.
   // This serves a function similar to `spark.memory.fraction`, but guarantees 
that we reserve
@@ -199,6 +263,181 @@ object UnifiedMemoryManager {
   // the memory used for execution and storage will be (1024 - 300) * 0.6 = 
434MB by default.
   private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024
 
+  private val unmanagedMemoryConsumers =
+    new ConcurrentHashMap[UnmanagedMemoryConsumerId, UnmanagedMemoryConsumer]
+
+  // Cached unmanaged memory usage values updated by polling
+  private val unmanagedOnHeapUsed = new AtomicLong(0L)
+  private val unmanagedOffHeapUsed = new AtomicLong(0L)
+
+  // Atomic flag to ensure polling is only started once per JVM
+  private val pollingStarted = new AtomicBoolean(false)
+
+  /**
+   * Register an unmanaged memory consumer to track its memory usage.
+   *
+   * Unmanaged memory consumers are components that manage their own memory 
outside
+   * of Spark's unified memory management system. By registering, their memory 
usage
+   * will be periodically polled and factored into Spark's memory allocation 
decisions.
+   *
+   * @param unmanagedMemoryConsumer The consumer to register for memory 
tracking
+   */
+  def registerUnmanagedMemoryConsumer(
+                                       unmanagedMemoryConsumer: 
UnmanagedMemoryConsumer): Unit = {
+    val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId
+    unmanagedMemoryConsumers.put(id, unmanagedMemoryConsumer)
+  }
+
+  /**
+   * Unregister an unmanaged memory consumer.
+   * This should be called when a component is shutting down to prevent memory 
leaks
+   * and ensure accurate memory tracking.
+   *
+   * @param unmanagedMemoryConsumer The consumer to unregister. Only used in 
tests
+   */
+  private[spark] def unregisterUnmanagedMemoryConsumer(
+      unmanagedMemoryConsumer: UnmanagedMemoryConsumer): Unit = {
+    val id = unmanagedMemoryConsumer.unmanagedMemoryConsumerId
+    unmanagedMemoryConsumers.remove(id)
+  }
+
+
+  /**
+   * Get the current memory usage in bytes for a specific component type.
+   * @param componentType The type of component to filter by (e.g., "RocksDB")
+   * @return Total memory usage in bytes for the specified component type
+   */
+  def getMemoryByComponentType(componentType: String): Long = {
+    unmanagedMemoryConsumers.asScala.values.toSeq
+      .filter(_.unmanagedMemoryConsumerId.componentType == componentType)
+      .map { memoryUser =>
+        try {
+          memoryUser.getMemBytesUsed
+        } catch {
+          case e: Exception =>
+            0L
+        }
+      }
+      .sum
+  }
+
+  /**
+   * Clear all unmanaged memory users.
+   * This is useful during executor shutdown or cleanup.
+   * Since each executor runs in its own JVM, this clears all users for this 
executor.
+   */
+  def clearUnmanagedMemoryUsers(): Unit = {
+    unmanagedMemoryConsumers.clear()
+    // Reset cached values when clearing consumers
+    unmanagedOnHeapUsed.set(0L)
+    unmanagedOffHeapUsed.set(0L)
+  }
+
+  // Shared polling infrastructure - only one polling thread per JVM
+  @volatile private var unmanagedMemoryPoller: ScheduledExecutorService = _
+
+  /**
+   * Start unmanaged memory polling if not already started.
+   * This ensures only one polling thread is created per JVM, regardless of 
how many
+   * UnifiedMemoryManager instances are created.
+   */
+  private[memory] def startPollingIfNeeded(pollingIntervalMs: Long): Unit = {
+    if (pollingStarted.compareAndSet(false, true)) {
+      unmanagedMemoryPoller = 
ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+        "unmanaged-memory-poller")
+
+      val pollingTask = new Runnable {
+        override def run(): Unit = Utils.tryLogNonFatalError {
+          pollUnmanagedMemoryUsers()
+        }
+      }
+
+      unmanagedMemoryPoller.scheduleAtFixedRate(
+        pollingTask,
+        0L, // initial delay
+        pollingIntervalMs,
+        TimeUnit.MILLISECONDS)
+
+      logInfo(log"Unmanaged memory polling started with interval " +
+        log"${MDC(LogKeys.TIME, pollingIntervalMs)}ms")
+    }
+  }
+
+  private def pollUnmanagedMemoryUsers(): Unit = {
+    val consumers = unmanagedMemoryConsumers.asScala.toMap
+
+    // Get memory usage for each consumer, handling failures gracefully
+    val memoryUsages = consumers.map { case (userId, memoryUser) =>
+      try {
+        val memoryUsed = memoryUser.getMemBytesUsed
+        if (memoryUsed == -1L) {
+          logDebug(log"Unmanaged memory consumer ${MDC(LogKeys.OBJECT_ID, 
userId.toString)} " +
+            log"is no longer active, marking for removal")
+          (userId, memoryUser, None) // Mark for removal
+        } else if (memoryUsed < 0L) {
+          logWarning(log"Invalid memory usage value ${MDC(LogKeys.NUM_BYTES, 
memoryUsed)} " +
+            log"from unmanaged memory user ${MDC(LogKeys.OBJECT_ID, 
userId.toString)}")
+          (userId, memoryUser, Some(0L)) // Treat as 0
+        } else {
+          (userId, memoryUser, Some(memoryUsed))
+        }
+      } catch {
+        case NonFatal(e) =>
+          logWarning(log"Failed to get memory usage for unmanaged memory user 
" +
+            log"${MDC(LogKeys.OBJECT_ID, userId.toString)} 
${MDC(LogKeys.EXCEPTION, e)}")
+          (userId, memoryUser, Some(0L)) // Treat as 0 on error
+      }
+    }
+
+    // Remove inactive consumers
+    memoryUsages.filter(_._3.isEmpty).foreach { case (userId, _, _) =>
+      unmanagedMemoryConsumers.remove(userId)
+      logInfo(log"Removed inactive unmanaged memory consumer " +
+        log"${MDC(LogKeys.OBJECT_ID, userId.toString)}")
+    }
+    // Calculate total memory usage by mode
+    val activeUsages = memoryUsages.filter(_._3.isDefined)
+    val onHeapTotal = activeUsages
+      .filter(_._2.memoryMode == MemoryMode.ON_HEAP)
+      .map(_._3.get)
+      .sum
+    val offHeapTotal = activeUsages
+      .filter(_._2.memoryMode == MemoryMode.OFF_HEAP)
+      .map(_._3.get)
+      .sum
+    // Update cached values atomically
+    unmanagedOnHeapUsed.set(onHeapTotal)
+    unmanagedOffHeapUsed.set(offHeapTotal)
+    // Log polling results for monitoring
+    val totalMemoryUsed = onHeapTotal + offHeapTotal
+    val numConsumers = activeUsages.size
+    logDebug(s"Unmanaged memory polling completed: $numConsumers consumers, " +
+      s"total memory used: ${totalMemoryUsed} bytes " +
+      s"(on-heap: ${onHeapTotal}, off-heap: ${offHeapTotal})")
+  }
+
+  /**
+   * Shutdown the unmanaged memory polling thread. Only used in tests
+   */
+  private[spark] def shutdownUnmanagedMemoryPoller(): Unit = {
+    synchronized {
+      if (unmanagedMemoryPoller != null) {
+        unmanagedMemoryPoller.shutdown()
+        try {
+          if (!unmanagedMemoryPoller.awaitTermination(5, TimeUnit.SECONDS)) {
+            unmanagedMemoryPoller.shutdownNow()
+          }
+        } catch {
+          case _: InterruptedException =>
+            Thread.currentThread().interrupt()
+        }
+        unmanagedMemoryPoller = null
+        pollingStarted.set(false)
+        logInfo(log"Unmanaged memory poller shutdown complete")
+      }
+    }
+  }
+
   def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = {
     val maxMemory = getMaxMemory(conf)
     new UnifiedMemoryManager(
@@ -242,3 +481,57 @@ object UnifiedMemoryManager {
     (usableMemory * memoryFraction).toLong
   }
 }
+
+/**
+ * Identifier for an unmanaged memory consumer.
+ *
+ * @param componentType The type of component (e.g., "RocksDB", 
"NativeLibrary")
+ * @param instanceKey A unique key to identify this specific instance of the 
component.
+ *                    For shared memory consumers, this should be a common key 
across
+ *                    all instances to avoid double counting.
+ */
+case class UnmanagedMemoryConsumerId(
+                                      componentType: String,
+                                      instanceKey: String
+                                    )
+
+/**
+ * Interface for components that consume memory outside of Spark's unified 
memory management.
+ *
+ * Components implementing this trait can register themselves with the memory 
manager
+ * to have their memory usage tracked and factored into memory allocation 
decisions.
+ * This helps prevent OOM errors when unmanaged components use significant 
memory.
+ *
+ * Examples of unmanaged memory consumers:
+ * - RocksDB state stores in structured streaming
+ * - Native libraries with custom memory allocation
+ * - Off-heap caches managed outside of Spark
+ */
+trait UnmanagedMemoryConsumer {
+  /**
+   * Returns the unique identifier for this memory consumer.
+   * The identifier is used to track and manage the consumer in the memory 
tracking system.
+   */
+  def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId
+
+  /**
+   * Returns the memory mode (ON_HEAP or OFF_HEAP) that this consumer uses.
+   * This is used to ensure unmanaged memory usage only affects the correct 
memory pool.
+   */
+  def memoryMode: MemoryMode
+
+  /**
+   * Returns the current memory usage in bytes.
+   *
+   * This method is called periodically by the memory polling mechanism to 
track
+   * memory usage over time. Implementations should return the current total 
memory
+   * consumed by this component.
+   *
+   * @return Current memory usage in bytes. Should return 0 if no memory is 
currently used.
+   *         Return -1L to indicate this consumer is no longer active and 
should be
+   *         automatically removed from tracking.
+   * @throws Exception if memory usage cannot be determined. The polling 
mechanism
+   *                   will handle exceptions gracefully and log warnings.
+   */
+  def getMemBytesUsed: Long
+}
diff --git 
a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
index 0cafe6891c7d..9c74f2fdd459 100644
--- 
a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
@@ -340,5 +340,289 @@ class UnifiedMemoryManagerSuite extends 
MemoryManagerSuite with PrivateMethodTes
     assert(mm.acquireStorageMemory(dummyBlock, 100L, memoryMode))
     assertEvictBlocksToFreeSpaceCalled(ms, 50)
     assert(mm.storageMemoryUsed === 600L)
+    UnifiedMemoryManager.shutdownUnmanagedMemoryPoller()
+  }
+
+  test("unmanaged memory tracking with memory mode separation") {
+    val maxMemory = 1000L
+    val taskAttemptId = 0L
+    val conf = new SparkConf()
+      .set(MEMORY_FRACTION, 1.0)
+      .set(TEST_MEMORY, maxMemory)
+      .set(MEMORY_OFFHEAP_ENABLED, false)
+      .set(MEMORY_STORAGE_FRACTION, storageFraction)
+      .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L) // 100ms polling
+    val mm = UnifiedMemoryManager(conf, numCores = 1)
+    val memoryMode = MemoryMode.ON_HEAP
+
+    // Mock unmanaged memory consumer for ON_HEAP
+    class MockOnHeapMemoryConsumer(var memoryUsed: Long) extends 
UnmanagedMemoryConsumer {
+      override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
+        UnmanagedMemoryConsumerId("TestOnHeap", "test-instance")
+      override def memoryMode: MemoryMode = MemoryMode.ON_HEAP
+      override def getMemBytesUsed: Long = memoryUsed
+    }
+
+    // Mock unmanaged memory consumer for OFF_HEAP
+    class MockOffHeapMemoryConsumer(var memoryUsed: Long) extends 
UnmanagedMemoryConsumer {
+      override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
+        UnmanagedMemoryConsumerId("TestOffHeap", "test-instance")
+      override def memoryMode: MemoryMode = MemoryMode.OFF_HEAP
+      override def getMemBytesUsed: Long = memoryUsed
+    }
+
+    val onHeapConsumer = new MockOnHeapMemoryConsumer(0L)
+    val offHeapConsumer = new MockOffHeapMemoryConsumer(0L)
+
+    try {
+      // Register both consumers
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(onHeapConsumer)
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(offHeapConsumer)
+
+      // Initially no unmanaged memory usage
+      assert(UnifiedMemoryManager.getMemoryByComponentType("TestOnHeap") === 
0L)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("TestOffHeap") === 
0L)
+
+      // Set off-heap memory usage - this should NOT affect on-heap allocations
+      offHeapConsumer.memoryUsed = 200L
+
+      // Wait for polling to pick up the change
+      Thread.sleep(200)
+
+      // Test that off-heap unmanaged memory doesn't affect on-heap execution 
memory allocation
+      val acquiredMemory = mm.acquireExecutionMemory(1000L, taskAttemptId, 
memoryMode)
+      // Should get full 1000 bytes since off-heap unmanaged memory doesn't 
affect on-heap pool
+      assert(acquiredMemory == 1000L)
+
+      // Release execution memory
+      mm.releaseExecutionMemory(acquiredMemory, taskAttemptId, memoryMode)
+
+      // Now set on-heap memory usage - this SHOULD affect on-heap allocations
+      onHeapConsumer.memoryUsed = 200L
+      Thread.sleep(200)
+
+      // Test that on-heap unmanaged memory affects on-heap execution memory 
allocation
+      val acquiredMemory2 = mm.acquireExecutionMemory(900L, taskAttemptId, 
memoryMode)
+      // Should only get 800 bytes due to 200 bytes of on-heap unmanaged 
memory usage
+      assert(acquiredMemory2 == 800L)
+
+      // Release execution memory to test storage allocation
+      mm.releaseExecutionMemory(acquiredMemory2, taskAttemptId, memoryMode)
+
+      // Test storage memory with on-heap unmanaged memory consideration
+      onHeapConsumer.memoryUsed = 300L
+      Thread.sleep(200)
+
+      // Storage should fail when block size + unmanaged memory > max memory
+      assert(!mm.acquireStorageMemory(dummyBlock, 800L, memoryMode))
+
+      // But smaller storage requests should succeed with unmanaged memory 
factored in
+      // With 300L on-heap unmanaged memory, effective max is 700L
+      assert(mm.acquireStorageMemory(dummyBlock, 600L, memoryMode))
+
+    } finally {
+      UnifiedMemoryManager.shutdownUnmanagedMemoryPoller()
+      UnifiedMemoryManager.clearUnmanagedMemoryUsers()
+    }
+  }
+
+  test("unmanaged memory consumer registration and unregistration") {
+    val conf = new SparkConf()
+      .set(MEMORY_FRACTION, 1.0)
+      .set(TEST_MEMORY, 1000L)
+      .set(MEMORY_OFFHEAP_ENABLED, false)
+      .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L)
+
+    val mm = UnifiedMemoryManager(conf, numCores = 1)
+
+    class MockMemoryConsumer(
+        var memoryUsed: Long,
+        instanceId: String,
+        mode: MemoryMode = MemoryMode.ON_HEAP) extends UnmanagedMemoryConsumer 
{
+      override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
+        UnmanagedMemoryConsumerId("Test", instanceId)
+      override def memoryMode: MemoryMode = mode
+      override def getMemBytesUsed: Long = memoryUsed
+    }
+
+    val consumer1 = new MockMemoryConsumer(100L, "test-instance-1")
+    val consumer2 = new MockMemoryConsumer(200L, "test-instance-2")
+
+    try {
+      // Register consumers
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer1)
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer2)
+
+      Thread.sleep(200)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 300L)
+
+      // Unregister one consumer
+      UnifiedMemoryManager.unregisterUnmanagedMemoryConsumer(consumer1)
+
+      Thread.sleep(200)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 200L)
+
+      // Unregister second consumer
+      UnifiedMemoryManager.unregisterUnmanagedMemoryConsumer(consumer2)
+
+      Thread.sleep(200)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 0L)
+
+    } finally {
+      UnifiedMemoryManager.shutdownUnmanagedMemoryPoller()
+      UnifiedMemoryManager.clearUnmanagedMemoryUsers()
+    }
+  }
+
+  test("unmanaged memory consumer auto-removal when returning -1") {
+    val conf = new SparkConf()
+      .set(MEMORY_FRACTION, 1.0)
+      .set(TEST_MEMORY, 1000L)
+      .set(MEMORY_OFFHEAP_ENABLED, false)
+      .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L)
+
+    val mm = UnifiedMemoryManager(conf, numCores = 1)
+
+    class MockMemoryConsumer(var memoryUsed: Long) extends 
UnmanagedMemoryConsumer {
+      override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
+        UnmanagedMemoryConsumerId("Test", s"test-instance-${this.hashCode()}")
+      override def memoryMode: MemoryMode = MemoryMode.ON_HEAP
+      override def getMemBytesUsed: Long = memoryUsed
+    }
+
+    val consumer1 = new MockMemoryConsumer(100L)
+    val consumer2 = new MockMemoryConsumer(200L)
+
+    try {
+      // Register consumers
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer1)
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer2)
+
+      Thread.sleep(200)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 300L)
+
+      // Mark consumer1 as inactive
+      consumer1.memoryUsed = -1L
+
+      // Wait for polling to detect and remove the inactive consumer
+      Thread.sleep(200)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 200L)
+
+      // Mark consumer2 as inactive as well
+      consumer2.memoryUsed = -1L
+
+      Thread.sleep(200)
+      assert(UnifiedMemoryManager.getMemoryByComponentType("Test") === 0L)
+
+    } finally {
+      UnifiedMemoryManager.shutdownUnmanagedMemoryPoller()
+      UnifiedMemoryManager.clearUnmanagedMemoryUsers()
+    }
+  }
+
+  test("unmanaged memory polling disabled when interval is zero") {
+    val conf = new SparkConf()
+      .set(MEMORY_FRACTION, 1.0)
+      .set(TEST_MEMORY, 1000L)
+      .set(MEMORY_OFFHEAP_ENABLED, false)
+      .set(MEMORY_STORAGE_FRACTION, storageFraction)
+      .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 0L) // Disabled
+
+    val mm = UnifiedMemoryManager(conf, numCores = 1)
+
+    // When polling is disabled, unmanaged memory should not affect allocations
+    class MockUnmanagedMemoryConsumer(var memoryUsed: Long) extends 
UnmanagedMemoryConsumer {
+      override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
+        UnmanagedMemoryConsumerId("Test", "test-instance")
+      override def memoryMode: MemoryMode = MemoryMode.ON_HEAP
+      override def getMemBytesUsed: Long = memoryUsed
+    }
+
+    val consumer = new MockUnmanagedMemoryConsumer(500L)
+
+    try {
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(consumer)
+
+      // Since polling is disabled, should be able to allocate full memory
+      val acquiredMemory = mm.acquireExecutionMemory(1000L, 0L, 
MemoryMode.ON_HEAP)
+      assert(acquiredMemory === 1000L)
+
+    } finally {
+      UnifiedMemoryManager.shutdownUnmanagedMemoryPoller()
+      UnifiedMemoryManager.clearUnmanagedMemoryUsers()
+    }
+  }
+
+  test("unmanaged memory tracking with off-heap memory enabled") {
+    val maxOnHeapMemory = 1000L
+    val maxOffHeapMemory = 1500L
+    val taskAttemptId = 0L
+    val conf = new SparkConf()
+      .set(MEMORY_FRACTION, 1.0)
+      .set(TEST_MEMORY, maxOnHeapMemory)
+      .set(MEMORY_OFFHEAP_ENABLED, true)
+      .set(MEMORY_OFFHEAP_SIZE, maxOffHeapMemory)
+      .set(MEMORY_STORAGE_FRACTION, storageFraction)
+      .set(UNMANAGED_MEMORY_POLLING_INTERVAL, 100L)
+    val mm = UnifiedMemoryManager(conf, numCores = 1)
+
+    // Mock unmanaged memory consumer
+    class MockUnmanagedMemoryConsumer(var memoryUsed: Long) extends 
UnmanagedMemoryConsumer {
+      override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId =
+        UnmanagedMemoryConsumerId("ExternalLib", "test-instance")
+
+      override def memoryMode: MemoryMode = MemoryMode.OFF_HEAP
+
+      override def getMemBytesUsed: Long = memoryUsed
+    }
+
+    val unmanagedConsumer = new MockUnmanagedMemoryConsumer(0L)
+
+    try {
+      // Register the unmanaged memory consumer
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(unmanagedConsumer)
+
+      // Test off-heap memory allocation with unmanaged memory
+      unmanagedConsumer.memoryUsed = 300L
+      Thread.sleep(200)
+
+      // Test off-heap execution memory
+      // With 300 bytes of unmanaged memory, effective off-heap memory should 
be reduced
+      val offHeapAcquired = mm.acquireExecutionMemory(1400L, taskAttemptId, 
MemoryMode.OFF_HEAP)
+      assert(offHeapAcquired <= 1200L, "Off-heap memory should be reduced by 
unmanaged usage")
+      mm.releaseExecutionMemory(offHeapAcquired, taskAttemptId, 
MemoryMode.OFF_HEAP)
+
+      // Test off-heap storage memory
+      unmanagedConsumer.memoryUsed = 500L
+      Thread.sleep(200)
+
+      // Storage should fail when block size + unmanaged memory > max off-heap 
memory
+      assert(!mm.acquireStorageMemory(dummyBlock, 1100L, MemoryMode.OFF_HEAP))
+
+      // But smaller off-heap storage requests should succeed
+      assert(mm.acquireStorageMemory(dummyBlock, 900L, MemoryMode.OFF_HEAP))
+      mm.releaseStorageMemory(900L, MemoryMode.OFF_HEAP)
+
+      // Test that on-heap is NOT affected by off-heap unmanaged memory
+      val onHeapAcquired = mm.acquireExecutionMemory(600L, taskAttemptId, 
MemoryMode.ON_HEAP)
+      assert(onHeapAcquired == 600L,
+        "On-heap memory should not be reduced by off-heap unmanaged usage")
+      mm.releaseExecutionMemory(onHeapAcquired, taskAttemptId, 
MemoryMode.ON_HEAP)
+
+      // Test with mixed memory modes
+      unmanagedConsumer.memoryUsed = 200L
+      Thread.sleep(200)
+
+      // Allocate some on-heap and off-heap memory
+      val onHeap = mm.acquireExecutionMemory(400L, taskAttemptId, 
MemoryMode.ON_HEAP)
+      val offHeap = mm.acquireExecutionMemory(1000L, taskAttemptId, 
MemoryMode.OFF_HEAP)
+
+      assert(onHeap == 400L && offHeap <= 1300L,
+        "Off-heap memory pool should respect unmanaged memory usage, on-heap 
should not")
+
+    } finally {
+      UnifiedMemoryManager.shutdownUnmanagedMemoryPoller()
+      UnifiedMemoryManager.clearUnmanagedMemoryUsers()
+    }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 90662cbb6ca7..dd8a99500a1a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -27,6 +27,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, 
AtomicInteger, AtomicLong}
 import scala.collection.{mutable, Map}
 import scala.jdk.CollectionConverters.ConcurrentMapHasAsScala
 import scala.util.Try
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
 import org.json4s.{Formats, NoTypeHints}
@@ -73,7 +74,8 @@ class RocksDB(
     useColumnFamilies: Boolean = false,
     enableStateStoreCheckpointIds: Boolean = false,
     partitionId: Int = 0,
-    eventForwarder: Option[RocksDBEventForwarder] = None) extends Logging {
+    eventForwarder: Option[RocksDBEventForwarder] = None,
+    uniqueId: String = "") extends Logging {
 
   import RocksDB._
 
@@ -181,6 +183,24 @@ class RocksDB(
   protected var sessionStateStoreCkptId: Option[String] = None
   protected[sql] val lineageManager: RocksDBLineageManager = new 
RocksDBLineageManager
 
+  // Memory tracking fields for unmanaged memory monitoring
+  // This allows the UnifiedMemoryManager to track RocksDB memory usage without
+  // directly accessing RocksDB from the polling thread, avoiding segmentation faults
+
+  // Timestamp of the last memory usage update in milliseconds.
+  // Used to enforce the update interval and prevent excessive memory queries.
+  private val lastMemoryUpdateTime = new AtomicLong(0L)
+
+  // Minimum interval between memory usage updates in milliseconds.
+  // This prevents performance impact from querying RocksDB memory too frequently.
+  private val memoryUpdateIntervalMs = conf.memoryUpdateIntervalMs
+
+  // Register with RocksDBMemoryManager if we have a unique ID
+  if (uniqueId.nonEmpty) {
+    // Initial registration with zero memory usage
+    RocksDBMemoryManager.updateMemoryUsage(uniqueId, 0L, 
conf.boundedMemoryUsage)
+  }
+
   @volatile private var numKeysOnLoadedVersion = 0L
   @volatile private var numKeysOnWritingVersion = 0L
 
@@ -573,6 +593,10 @@ class RocksDB(
     } else {
       loadWithoutCheckpointId(version, readOnly)
     }
+
+    // Register with memory manager after successful load
+    updateMemoryUsageIfNeeded()
+
     this
   }
 
@@ -754,6 +778,7 @@ class RocksDB(
   def get(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Array[Byte] = {
+    updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies) {
       encodeStateRowWithPrefix(key, cfName)
     } else {
@@ -821,6 +846,7 @@ class RocksDB(
       value: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false): Unit = {
+    updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
     } else {
@@ -848,6 +874,7 @@ class RocksDB(
       value: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false): Unit = {
+    updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
     } else {
@@ -867,6 +894,7 @@ class RocksDB(
       key: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME,
       includesPrefix: Boolean = false): Unit = {
+    updateMemoryUsageIfNeeded()
     val keyWithPrefix = if (useColumnFamilies && !includesPrefix) {
       encodeStateRowWithPrefix(key, cfName)
     } else {
@@ -882,6 +910,7 @@ class RocksDB(
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
   def iterator(): Iterator[ByteArrayPair] = {
+    updateMemoryUsageIfNeeded()
     val iter = db.newIterator()
     logInfo(log"Getting iterator from version ${MDC(LogKeys.LOADED_VERSION, 
loadedVersion)}")
     iter.seekToFirst()
@@ -918,6 +947,7 @@ class RocksDB(
    * Get an iterator of all committed and uncommitted key-value pairs for the 
given column family.
    */
   def iterator(cfName: String): Iterator[ByteArrayPair] = {
+    updateMemoryUsageIfNeeded()
     if (!useColumnFamilies) {
       iterator()
     } else {
@@ -967,6 +997,7 @@ class RocksDB(
   def prefixScan(
       prefix: Array[Byte],
       cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[ByteArrayPair] = {
+    updateMemoryUsageIfNeeded()
     val iter = db.newIterator()
     val updatedPrefix = if (useColumnFamilies) {
       encodeStateRowWithPrefix(prefix, cfName)
@@ -1013,6 +1044,7 @@ class RocksDB(
    * - Sync the checkpoint dir files to DFS
    */
   def commit(): (Long, StateStoreCheckpointInfo) = {
+    updateMemoryUsageIfNeeded()
     val newVersion = loadedVersion + 1
     try {
       logInfo(log"Flushing updates for ${MDC(LogKeys.VERSION_NUM, 
newVersion)}")
@@ -1227,6 +1259,17 @@ class RocksDB(
         snapshot = snapshotsToUploadQueue.poll()
       }
 
+      // Unregister from RocksDBMemoryManager
+      if (uniqueId.nonEmpty) {
+        try {
+          RocksDBMemoryManager.unregisterInstance(uniqueId)
+        } catch {
+          case NonFatal(e) =>
+            logWarning(log"Failed to unregister from RocksDBMemoryManager " +
+              log"${MDC(LogKeys.EXCEPTION, e)}")
+        }
+      }
+
       silentDeleteRecursively(localRootDir, "closing RocksDB")
       // Clear internal maps to reset the state
       clearColFamilyMaps()
@@ -1339,6 +1382,53 @@ class RocksDB(
 
   private def getDBProperty(property: String): Long = 
db.getProperty(property).toLong
 
+  /**
+   * Returns the current memory usage of this RocksDB instance in bytes.
+   * WARNING: This method should only be called from the task thread when
+   * RocksDB is in a safe state.
+   *
+   * This includes memory from all major RocksDB components:
+   * - Table readers (indexes and filters in memory)
+   * - Memtables (write buffers)
+   * - Block cache (cached data blocks)
+   * - Block cache pinned usage (blocks pinned in cache)
+   *
+   * @return Total memory usage in bytes across all tracked components
+   */
+  def getMemoryUsage: Long = {
+
+    require(db != null && !db.isClosed, "RocksDB must be open to get memory 
usage")
+    RocksDB.mainMemorySources.map { memorySource =>
+      getDBProperty(memorySource)
+    }.sum
+  }
+
+  /**
+   * Updates the cached memory usage if enough time has passed.
+   * This is called from task thread operations, so it's already thread-safe.
+   */
+  def updateMemoryUsageIfNeeded(): Unit = {
+    if (uniqueId.isEmpty) return // No tracking without unique ID
+
+    val currentTime = System.currentTimeMillis()
+    val timeSinceLastUpdate = currentTime - lastMemoryUpdateTime.get()
+
+    if (timeSinceLastUpdate >= memoryUpdateIntervalMs && db != null && 
!db.isClosed) {
+      try {
+        val usage = getMemoryUsage
+        lastMemoryUpdateTime.set(currentTime)
+        // Report usage to RocksDBMemoryManager
+        RocksDBMemoryManager.updateMemoryUsage(
+          uniqueId,
+          usage,
+          conf.boundedMemoryUsage)
+      } catch {
+        case NonFatal(e) =>
+          logDebug(s"Failed to update RocksDB memory usage: ${e.getMessage}")
+      }
+    }
+  }
+
   private def openDB(): Unit = {
     assert(db == null)
     db = NativeRocksDB.open(rocksDbOptions, workingDir.toString)
@@ -1458,6 +1548,13 @@ class RocksDB(
 }
 
 object RocksDB extends Logging {
+
+  val mainMemorySources: Seq[String] = Seq(
+    "rocksdb.estimate-table-readers-mem",
+    "rocksdb.cur-size-all-mem-tables",
+    "rocksdb.block-cache-usage",
+    "rocksdb.block-cache-pinned-usage")
+
   case class RocksDBSnapshot(
       checkpointDir: File,
       version: Long,
@@ -1699,6 +1796,7 @@ case class RocksDBConf(
     totalMemoryUsageMB: Long,
     writeBufferCacheRatio: Double,
     highPriorityPoolRatio: Double,
+    memoryUpdateIntervalMs: Long,
     compressionCodec: String,
     allowFAllocate: Boolean,
     compression: String,
@@ -1785,6 +1883,12 @@ object RocksDBConf {
   private val HIGH_PRIORITY_POOL_RATIO_CONF = 
SQLConfEntry(HIGH_PRIORITY_POOL_RATIO_CONF_KEY,
     "0.1")
 
+  // Memory usage update interval for unmanaged memory tracking
+  val MEMORY_UPDATE_INTERVAL_MS_CONF_KEY = "memoryUpdateIntervalMs"
+  private val MEMORY_UPDATE_INTERVAL_MS_CONF = 
SQLConfEntry(MEMORY_UPDATE_INTERVAL_MS_CONF_KEY,
+    "1000")
+
+
   // Allow files to be pre-allocated on disk using fallocate
   // Disabling may slow writes, but can solve an issue where
   // significant quantities of disk are wasted if there are
@@ -1883,6 +1987,7 @@ object RocksDBConf {
       getLongConf(MAX_MEMORY_USAGE_MB_CONF),
       getRatioConf(WRITE_BUFFER_CACHE_RATIO_CONF),
       getRatioConf(HIGH_PRIORITY_POOL_RATIO_CONF),
+      getPositiveLongConf(MEMORY_UPDATE_INTERVAL_MS_CONF),
       storeConf.compressionCodec,
       getBooleanConf(ALLOW_FALLOCATE_CONF),
       getStringConf(COMPRESSION_CONF),
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
index 273cbbc5e87d..80ad864600b2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBMemoryManager.scala
@@ -17,22 +17,93 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.jdk.CollectionConverters._
+
 import org.rocksdb._
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
+import org.apache.spark.memory.{MemoryMode, UnifiedMemoryManager, 
UnmanagedMemoryConsumer, UnmanagedMemoryConsumerId}
 
 /**
  * Singleton responsible for managing cache and write buffer manager associated with all RocksDB
  * state store instances running on a single executor if boundedMemoryUsage is enabled for RocksDB.
  * If boundedMemoryUsage is disabled, a new cache object is returned.
+ * This also implements UnmanagedMemoryConsumer to report RocksDB memory usage to Spark's
+ * UnifiedMemoryManager, allowing Spark to account for RocksDB memory when making
+ * memory allocation decisions.
  */
-object RocksDBMemoryManager extends Logging {
+object RocksDBMemoryManager extends Logging with UnmanagedMemoryConsumer{
   private var writeBufferManager: WriteBufferManager = null
   private var cache: Cache = null
 
+  // Tracks memory usage and bounded memory mode per unique ID
+  private case class InstanceMemoryInfo(memoryUsage: Long, isBoundedMemory: 
Boolean)
+  private val instanceMemoryMap = new ConcurrentHashMap[String, 
InstanceMemoryInfo]()
+
+  override def unmanagedMemoryConsumerId: UnmanagedMemoryConsumerId = {
+    UnmanagedMemoryConsumerId("RocksDB", "RocksDB-Memory-Manager")
+  }
+
+  override def memoryMode: MemoryMode = {
+    // RocksDB uses native/off-heap memory for its data structures
+    MemoryMode.OFF_HEAP
+  }
+
+  override def getMemBytesUsed: Long = {
+    val memoryInfos = instanceMemoryMap.values().asScala.toSeq
+    if (memoryInfos.isEmpty) {
+      return 0L
+    }
+
+    // Separate instances by bounded vs unbounded memory mode
+    val (bounded, unbounded) = memoryInfos.partition(_.isBoundedMemory)
+
+    // For bounded memory instances, they all share the same memory pool,
+    // so just take the max value (they should all be similar)
+    val boundedMemory = if (bounded.nonEmpty) bounded.map(_.memoryUsage).max 
else 0L
+
+    // For unbounded memory instances, sum their individual usages
+    val unboundedMemory = unbounded.map(_.memoryUsage).sum
+
+    // Total is bounded memory (shared) + sum of unbounded memory (individual)
+    boundedMemory + unboundedMemory
+  }
+
+  /**
+   * Register/update a RocksDB instance with its memory usage.
+   * @param uniqueId The instance's unique identifier
+   * @param memoryUsage The current memory usage in bytes
+   * @param isBoundedMemory Whether this instance uses bounded memory mode
+   */
+  def updateMemoryUsage(
+      uniqueId: String,
+      memoryUsage: Long,
+      isBoundedMemory: Boolean): Unit = {
+    instanceMemoryMap.put(uniqueId, InstanceMemoryInfo(memoryUsage, 
isBoundedMemory))
+    logDebug(s"Updated memory usage for $uniqueId: $memoryUsage bytes " +
+      s"(bounded=$isBoundedMemory)")
+  }
+
+  /**
+   * Unregister a RocksDB instance.
+   * @param uniqueId The instance's unique identifier
+   */
+  def unregisterInstance(uniqueId: String): Unit = {
+    instanceMemoryMap.remove(uniqueId)
+    logDebug(s"Unregistered instance $uniqueId")
+  }
+
   def getOrCreateRocksDBMemoryManagerAndCache(conf: RocksDBConf): 
(WriteBufferManager, Cache)
     = synchronized {
+    // Register with UnifiedMemoryManager (idempotent operation)
+    if (SparkEnv.get != null) {
+      UnifiedMemoryManager.registerUnmanagedMemoryConsumer(this)
+    }
+
     if (conf.boundedMemoryUsage) {
       if (writeBufferManager == null) {
         assert(cache == null)
@@ -72,5 +143,6 @@ object RocksDBMemoryManager extends Logging {
   def resetWriteBufferManagerAndCache: Unit = synchronized {
     writeBufferManager = null
     cache = null
+    instanceMemoryMap.clear()
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index a702d041de7e..c0eef53e12ba 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -562,6 +562,7 @@ private[sql] class RocksDBStateStoreProvider
     this.rocksDBEventForwarder =
       Some(RocksDBEventForwarder(StateStoreProvider.getRunId(hadoopConf), 
stateStoreId))
 
+    // Initialize StateStoreProviderId for memory tracking
     val queryRunId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf))
     this.stateStoreProviderId = StateStoreProviderId(stateStoreId, queryRunId)
 
@@ -775,7 +776,8 @@ private[sql] class RocksDBStateStoreProvider
       useColumnFamilies: Boolean,
       enableStateStoreCheckpointIds: Boolean,
       partitionId: Int = 0,
-      eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = {
+      eventForwarder: Option[RocksDBEventForwarder] = None,
+      uniqueId: String = ""): RocksDB = {
     new RocksDB(
       dfsRootDir,
       conf,
@@ -785,7 +787,8 @@ private[sql] class RocksDBStateStoreProvider
       useColumnFamilies,
       enableStateStoreCheckpointIds,
       partitionId,
-      eventForwarder)
+      eventForwarder,
+      uniqueId)
   }
 
   private[sql] lazy val rocksDB = {
@@ -797,7 +800,7 @@ private[sql] class RocksDBStateStoreProvider
     val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), 
storeIdStr)
     createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, 
hadoopConf, loggingId,
       useColumnFamilies, storeConf.enableStateStoreCheckpointIds, 
stateStoreId.partitionId,
-      rocksDBEventForwarder)
+      rocksDBEventForwarder, stateStoreProviderId.toString)
   }
 
   private val keyValueEncoderMap = new 
java.util.concurrent.ConcurrentHashMap[String,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala
index 9429cd5ef39e..fb698c89ff8e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FailureInjectionCheckpointFileManager.scala
@@ -258,7 +258,8 @@ class FailureInjectionRocksDBStateStoreProvider extends 
RocksDBStateStoreProvide
       useColumnFamilies: Boolean,
       enableStateStoreCheckpointIds: Boolean,
       partitionId: Int,
-      eventForwarder: Option[RocksDBEventForwarder] = None): RocksDB = {
+      eventForwarder: Option[RocksDBEventForwarder] = None,
+      uniqueId: String): RocksDB = {
     FailureInjectionRocksDBStateStoreProvider.createRocksDBWithFaultInjection(
       dfsRootDir,
       conf,
@@ -268,7 +269,8 @@ class FailureInjectionRocksDBStateStoreProvider extends 
RocksDBStateStoreProvide
       useColumnFamilies,
       enableStateStoreCheckpointIds,
       partitionId,
-      eventForwarder)
+      eventForwarder,
+      uniqueId)
   }
 }
 
@@ -286,7 +288,8 @@ object FailureInjectionRocksDBStateStoreProvider {
       useColumnFamilies: Boolean,
       enableStateStoreCheckpointIds: Boolean,
       partitionId: Int,
-      eventForwarder: Option[RocksDBEventForwarder]): RocksDB = {
+      eventForwarder: Option[RocksDBEventForwarder],
+      uniqueId: String): RocksDB = {
     new RocksDB(
       dfsRootDir,
       conf = conf,
@@ -296,7 +299,8 @@ object FailureInjectionRocksDBStateStoreProvider {
       useColumnFamilies = useColumnFamilies,
       enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
       partitionId = partitionId,
-      eventForwarder = eventForwarder
+      eventForwarder = eventForwarder,
+      uniqueId
     ) {
       override def createFileManager(
           dfsRootDir: String,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
index 0c3e457c8df1..4018971d20f4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBCheckpointFailureInjectionSuite.scala
@@ -602,7 +602,8 @@ class RocksDBCheckpointFailureInjectionSuite extends 
StreamTest
         useColumnFamilies = true,
         enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
         partitionId = 0,
-        eventForwarder = None)
+        eventForwarder = None,
+        uniqueId = "")
       db.load(version, checkpointId)
       func(db)
     } finally {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index e0af281fecb9..d2c95dfe5016 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters.SetHasAsScala
 import org.scalatest.time.{Minute, Span}
 
 import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StreamingQueryWrapper}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.functions.{count, max}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.OutputMode.Update
@@ -314,4 +314,80 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     assert(changelogVersionsPresent(dirForPartition0) == List(3L, 4L))
     assert(snapshotVersionsPresent(dirForPartition0).contains(5L))
   }
+
+  // Test with both bounded memory enabled and disabled
+  Seq(true, false).foreach { boundedMemoryEnabled =>
+    test(s"RocksDB memory tracking integration with UnifiedMemoryManager" +
+      s" with boundedMemory=$boundedMemoryEnabled") {
+      withTempDir { dir =>
+        withSQLConf(
+          (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName),
+          (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
+          (SQLConf.SHUFFLE_PARTITIONS.key -> "5"),
+          (SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> (5 * 60 * 
1000).toString),
+          ("spark.memory.unmanagedMemoryPollingInterval" -> "100ms"),
+          ("spark.sql.streaming.stateStore.rocksdb.boundedMemoryUsage" ->
+            boundedMemoryEnabled.toString)) {
+
+          import org.apache.spark.memory.UnifiedMemoryManager
+          import org.apache.spark.sql.streaming.Trigger
+
+          // Use rate stream to ensure continuous state operations that trigger memory updates
+          val query = spark.readStream
+            .format("rate")
+            .option("rowsPerSecond", "10") // Continuous but not overwhelming
+            .load()
+            .selectExpr("value % 100 as key", "value")
+            .groupBy("key")
+            .agg(count("*").as("count"), max("value").as("max_value"))
+            .writeStream
+            .format("console")
+            .outputMode("update")
+            .trigger(Trigger.ProcessingTime(200)) // Regular triggers to 
ensure state operations
+            .start()
+
+          try {
+            // Let the stream run to establish RocksDB instances and generate state operations
+            Thread.sleep(2000) // 2 seconds should be enough for several 
processing cycles
+
+            // Now check for memory tracking - the continuous stream should trigger memory updates
+            var rocksDBMemory = 0L
+            var attempts = 0
+            val maxAttempts = 15 // 15 attempts with 1-second intervals = 15 
seconds max
+
+            while (rocksDBMemory <= 0L && attempts < maxAttempts) {
+              Thread.sleep(1000) // Wait between checks to allow memory updates
+              rocksDBMemory = 
UnifiedMemoryManager.getMemoryByComponentType("RocksDB")
+              attempts += 1
+
+              if (rocksDBMemory > 0L) {
+                logInfo(s"RocksDB memory detected: $rocksDBMemory bytes " +
+                  s"after $attempts attempts with 
boundedMemory=$boundedMemoryEnabled")
+              }
+            }
+
+            // Verify memory tracking remains stable during continued operation
+            Thread.sleep(2000) // Let stream continue running
+
+            val finalMemory = 
UnifiedMemoryManager.getMemoryByComponentType("RocksDB")
+
+            // Memory should still be tracked (allow for some fluctuation)
+            assert(finalMemory > 0L,
+              s"RocksDB memory tracking should remain active during stream 
processing: " +
+                s"got $finalMemory bytes (initial: $rocksDBMemory) " +
+                s"with boundedMemory=$boundedMemoryEnabled")
+
+            logInfo(s"RocksDB memory tracking test completed successfully: " +
+              s"initial=$rocksDBMemory bytes, final=$finalMemory bytes " +
+              s"with boundedMemory=$boundedMemoryEnabled")
+
+          } finally {
+            query.stop()
+            // Clean up unmanaged memory users
+            UnifiedMemoryManager.clearUnmanagedMemoryUsers()
+          }
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to