This is an automated email from the ASF dual-hosted git repository.

HeartSaVioR pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a6ac0b8109c0 [SPARK-57141][SS][RTM][STREAMINGSHUFFLE][PART3] Add StreamingShuffleManager and MultiShuffleManager
a6ac0b8109c0 is described below

commit a6ac0b8109c02969d685908c37062566653918cc
Author: Boyang Jerry Peng <[email protected]>
AuthorDate: Mon Jun 8 07:13:16 2026 +0900

    [SPARK-57141][SS][RTM][STREAMINGSHUFFLE][PART3] Add StreamingShuffleManager and MultiShuffleManager
    
    ### What changes were proposed in this pull request?
    
      This is **part 3** of a multi-PR effort to add *streaming shuffle* to Spark: a push-based shuffle used by Real-Time Mode (RTM) Structured Streaming, in which writer tasks push records
      directly to reader tasks over the network instead of writing map output to disk for readers to pull.
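
      A toy, in-memory model of the push path may help picture the idea. It is a sketch under an assumption the PR does not state outright: every writer/reader pair carries a sequence number that starts at 0 and grows by one per message, which is what the `STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER` condition added here implies. `PushedRecord`, `ToyStreamingWriter`, and `ToyStreamingReader` are hypothetical names; there is no Netty or serialization involved, and the real writer and reader only arrive in parts 4 and 5.

      ```scala
      import scala.collection.mutable

      // One pushed message: writer id, per-(writer, reader) sequence number, payload.
      final case class PushedRecord(writerId: Int, seqNum: Long, payload: String)

      // Reader side: remembers the next sequence number it expects from each writer
      // and rejects anything that arrives out of order.
      final class ToyStreamingReader(val readerId: Int) {
        private val expectedSeq = mutable.Map.empty[Int, Long]
        val received: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String]

        def onMessage(msg: PushedRecord): Unit = {
          val expected = expectedSeq.getOrElse(msg.writerId, 0L)
          if (msg.seqNum != expected) {
            // Carries the same information as the error condition's parameters.
            throw new IllegalStateException(
              s"writer ${msg.writerId} -> reader $readerId: expected sequence number " +
                s"$expected, but the actual sequence number is ${msg.seqNum}")
          }
          expectedSeq(msg.writerId) = expected + 1
          received += msg.payload
        }
      }

      // Writer side: hands each record straight to the reader instead of writing
      // map output to disk and waiting for the reader to fetch it.
      final class ToyStreamingWriter(val writerId: Int, reader: ToyStreamingReader) {
        private var nextSeq = 0L

        def push(payload: String): Unit = {
          reader.onMessage(PushedRecord(writerId, nextSeq, payload))
          nextSeq += 1
        }
      }
      ```

      Because each writer keeps its own counter, messages from different writers may interleave freely; only a gap or reordering within one writer's stream is rejected.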
    
      This PR adds the shuffle-manager layer that later PRs plug into:
    
      - **`StreamingShuffleManager`**: a `ShuffleManager` implementation for streaming shuffle. `getWriter`/`getReader` are intentionally stubbed in this PR (they throw
      `UnsupportedOperationException`) and are implemented in the push-path and pull-path PRs that follow (parts 4 and 5).
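      - **`MultiShuffleManager`**: routes each shuffle to either the batch `SortShuffleManager` or the `StreamingShuffleManager` based on a per-query local property (`streaming.concurrent.stages.enabled`), so a single application
      can mix batch and streaming shuffle. The choice is made on the first use of a shuffle id and kept for the lifetime of that shuffle.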
      - **`TaskContextAwareLogging`**: a `Logging` mixin that prefixes log lines with queryId / shuffleId / stageId / taskId.
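      - **`SparkEnv`**: exposes the `StreamingShuffleOutputTracker` (added in part 2) as `SparkEnv.streamingShuffleOutputTracker` (a master hosting the RPC endpoint on the driver, a worker holding a driver endpoint reference on executors) and initializes it **only** when the configured shuffle manager is `StreamingShuffleManager`
      or `MultiShuffleManager`.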
      - Two streaming-shuffle error conditions (`STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER`, `STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE`) and the `STREAMING_QUERY_ID` log key.
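
      The prefix that `TaskContextAwareLogging.formatMessage` builds for plain string messages can be sketched as a pure function. `LogCtx` and `ToyTaskContextLogging` are hypothetical stand-ins for the `TaskContext` values the mixin reads; the prefix order, the 5-character truncation of the `sql.streaming.queryId` local property, and the fact that the `taskId` label is filled from `TaskContext.partitionId()` follow the diff below. Structured `LogEntry` messages take a different path that attaches only the query and shuffle ids as MDC values.

      ```scala
      // Hypothetical stand-in for the TaskContext values the mixin reads.
      final case class LogCtx(
          streamingQueryId: Option[String], // local property "sql.streaming.queryId"
          stageId: Option[Int],
          partitionId: Option[Int])

      object ToyTaskContextLogging {
        // Same order as formatMessage in the diff: query, shuffle, stage, task.
        def formatMessage(msg: String, ctx: LogCtx, shuffleId: Option[Int]): String = {
          // Only the first 5 characters of the query id are kept; empty ids are dropped.
          val queryIdMsg = ctx.streamingQueryId.map(_.take(5)).filter(_.nonEmpty)
            .map(q => s"[queryId = $q] ").getOrElse("")
          val shuffleIdMsg = shuffleId.map(s => s"[shuffleId = $s] ").getOrElse("")
          val stageIdMsg = ctx.stageId.map(s => s"[stageId = $s] ").getOrElse("")
          // Labelled "taskId", but populated from the partition id, as in the diff.
          val taskIdMsg = ctx.partitionId.map(p => s"[taskId = $p] ").getOrElse("")
          s"$queryIdMsg$shuffleIdMsg$stageIdMsg$taskIdMsg$msg"
        }
      }
      ```

      For example, a query id of `3f2a9c1d`, shuffle 7, stage 4 and partition 2 yield the prefix `[queryId = 3f2a9] [shuffleId = 7] [stageId = 4] [taskId = 2] `.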
    
      The full PR stack:
    
      - **Part 1** (SPARK-56674, *merged*): streaming shuffle wire protocol (Netty messages).
      - **Part 2** (SPARK-56962, *merged*): `StreamingShuffleOutputTracker` (driver-side writer-location coordination).
      - **Part 3** (*this PR*): shuffle-manager layer (`StreamingShuffleManager` + `MultiShuffleManager`), logging mixin, and SparkEnv tracker wiring.
      - **Part 4**: `StreamingShuffleWriter` + server-side Netty handler (push path).
      - **Part 5**: `StreamingShuffleReader` + client-side Netty handler (pull path).
      - **Part 6**: register streaming shuffles with the tracker in `DAGScheduler` (activation).
      - **Part 7**: end-to-end `StreamingShuffleSuite`.
      - **Part 8**: documentation.
    
      ### Why are the changes needed?
    
      Real-Time Mode (low-latency continuous) queries need shuffle data to flow continuously between stages. The default sort shuffle writes map output to disk and has reducers pull it, and the reducing stage does not start until the map stage has finished; this adds
      latency that is unacceptable for these workloads. Streaming shuffle instead pushes records directly from writer tasks to reader tasks.
    
      This PR lands the manager layer that the writer and reader implementations attach to, plus `MultiShuffleManager`, so that batch stages keep using sort shuffle while streaming stages use
      streaming shuffle within the same application.
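
      The routing rule inside `MultiShuffleManager` can be modelled without Spark. In this sketch `ToyMultiShuffleRouting`, `Backend`, and `Router` are hypothetical; the property name and the "pin on first use, forget on unregister" behaviour come from the diff (`shuffleIdToManager.computeIfAbsent` and `unregisterShuffle`). The real manager reads the local properties from the active `SparkContext` or, failing that, from `TaskContext`, and raises an internal error if neither is available; here they are simply passed in.

      ```scala
      import java.util.Properties
      import java.util.concurrent.ConcurrentHashMap

      object ToyMultiShuffleRouting {
        // Same property name as MultiShuffleManager.STREAMING_SHUFFLE_ENABLED_PROPERTY.
        val StreamingEnabledProperty = "streaming.concurrent.stages.enabled"

        // Hypothetical stand-ins for SortShuffleManager and StreamingShuffleManager.
        sealed trait Backend
        case object SortBackend extends Backend
        case object StreamingBackend extends Backend

        def isStreamingShuffleEnabled(props: Properties): Boolean =
          "true" == props.getProperty(StreamingEnabledProperty)

        final class Router {
          // Each shuffle id is pinned to the backend chosen the first time it is seen.
          private val shuffleIdToBackend = new ConcurrentHashMap[Int, Backend]()

          def backendFor(shuffleId: Int, localProps: Properties): Backend =
            shuffleIdToBackend.computeIfAbsent(shuffleId, _ =>
              if (isStreamingShuffleEnabled(localProps)) StreamingBackend else SortBackend)

          // Forget the decision once the shuffle is unregistered.
          def unregister(shuffleId: Int): Unit = {
            shuffleIdToBackend.remove(shuffleId)
            ()
          }
        }
      }
      ```

      Pinning keeps later `getWriter`/`getReader` calls for a shuffle on the same manager even if the local properties seen afterwards differ, which is the intent of the comment on `shuffleIdToManager`. Each JVM (the driver and every executor) holds its own map, so the decision is made independently in each.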
    
      ### Does this PR introduce _any_ user-facing change?
    
      No. The new shuffle managers are opt-in via `spark.shuffle.manager` and are not the default, and `getWriter`/`getReader` are still stubbed in this PR, so the feature is not yet usable
      end-to-end (that is completed in later PRs). The `StreamingShuffleOutputTracker` is initialized only when one of the new managers is configured, so the default (sort
      shuffle) path is unchanged; tests cover this.
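
      The gating that `initializeStreamingShuffleOutputTracker` adds to `SparkEnv` reduces to a small decision function. In this sketch the configuration is a plain `Map`, `ToyTrackerGating`, `TrackerRole`, `Master`, and `Worker` are hypothetical names for the `StreamingShuffleOutputTrackerMaster` / `StreamingShuffleOutputTrackerWorker` split, and the short-name table is an assumption about how `ShuffleManager.getShuffleManagerClassName` resolves aliases such as the default `sort`.

      ```scala
      import java.util.Locale

      object ToyTrackerGating {
        sealed trait TrackerRole
        case object Master extends TrackerRole // driver: hosts the tracker RPC endpoint
        case object Worker extends TrackerRole // executor: holds a reference to the driver endpoint

        // Assumption: short names resolve like Spark's built-in shuffle manager aliases.
        private val shortNames = Map(
          "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
          "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")

        // The two managers that, per this PR, require a StreamingShuffleOutputTracker.
        private val managersNeedingTracker = Set(
          "org.apache.spark.shuffle.streaming.StreamingShuffleManager",
          "org.apache.spark.shuffle.streaming.MultiShuffleManager")

        def resolveShuffleManager(conf: Map[String, String]): String = {
          val name = conf.getOrElse("spark.shuffle.manager", "sort")
          shortNames.getOrElse(name.toLowerCase(Locale.ROOT), name)
        }

        // None means no tracker is created, leaving the default sort-shuffle path untouched.
        def trackerRole(conf: Map[String, String], isDriver: Boolean): Option[TrackerRole] =
          if (!managersNeedingTracker.contains(resolveShuffleManager(conf))) None
          else if (isDriver) Some(Master)
          else Some(Worker)
      }
      ```

      With a real build, opting in means setting `spark.shuffle.manager` to `org.apache.spark.shuffle.streaming.MultiShuffleManager` (or `StreamingShuffleManager`); at this point in the stack, streaming shuffles still fail in `getWriter`/`getReader`.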
    
      ### How was this patch tested?
    
      New unit suites:
    
      - **`StreamingShuffleManagerSuite`**: `getWriterId` for data and termination messages and the unexpected-message-type error; `getQueryId` resolution and failure; the `registerShuffle` handle
      type; and SparkEnv gating (the tracker is present for `StreamingShuffleManager` and absent for the default manager).
      - **`MultiShuffleManagerSuite`**: per-query streaming-vs-batch routing, the enable property, and SparkEnv gating for `MultiShuffleManager`.
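
      The `getWriterId` cases listed above follow a simple dispatch on the message type. The sketch models it with a sealed trait; `Msg`, `DataMsg`, `TerminationControlMsg`, `TerminationAckMsg`, and `UnexpectedMessageType` are hypothetical stand-ins for the part 1 wire classes and for the `SparkRuntimeException` built from `STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE`. Only the type names and the writer-id extraction are taken from the diff.

      ```scala
      object ToyWriterIdDispatch {
        // Hypothetical stand-ins for the wire messages; only the fields used here.
        sealed trait Msg { def messageType: String }
        final case class DataMsg(shuffleWriterId: Int) extends Msg {
          def messageType: String = "DATA_MESSAGE_UNSAFE_ROW"
        }
        final case class TerminationControlMsg(shuffleWriterId: Int) extends Msg {
          def messageType: String = "TERMINATION_CONTROL_MESSAGE"
        }
        final case class TerminationAckMsg(readerId: Int) extends Msg {
          def messageType: String = "TERMINATION_ACK_MESSAGE"
        }

        final class UnexpectedMessageType(val messageType: String)
          extends RuntimeException(
            s"Unexpected message type $messageType encountered during streaming shuffle.")

        // Only writer-originated messages carry a writer id; anything else is an error.
        def getWriterId(msg: Msg): Int = msg match {
          case DataMsg(writerId) => writerId
          case TerminationControlMsg(writerId) => writerId
          case other => throw new UnexpectedMessageType(other.messageType)
        }
      }
      ```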
    
      13 tests, all passing. `SparkThrowableSuite` validates the two new error conditions.
    
      ### Was this patch authored or co-authored using generative AI tooling?
    
      Co-authored with Claude Code (Claude Opus 4.8)
    
    Closes #56196 from jerrypeng/stack/streaming-shuffle-pr3-managers.
    
    Authored-by: Boyang Jerry Peng <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../java/org/apache/spark/internal/LogKeys.java    |   1 +
 .../src/main/resources/error/error-conditions.json |  12 ++
 .../src/main/scala/org/apache/spark/SparkEnv.scala |  44 ++++++
 .../shuffle/streaming/MultiShuffleManager.scala    | 154 +++++++++++++++++++++
 .../streaming/StreamingShuffleManager.scala        | 130 +++++++++++++++++
 .../streaming/TaskContextAwareLogging.scala        | 109 +++++++++++++++
 .../streaming/MultiShuffleManagerSuite.scala       |  71 ++++++++++
 .../streaming/StreamingShuffleManagerSuite.scala   | 125 +++++++++++++++++
 8 files changed, 646 insertions(+)

diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index d8ce9d025af9..37064bf77631 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -794,6 +794,7 @@ public enum LogKeys implements LogKey {
   STREAMING_DATA_SOURCE_NAME,
   STREAMING_OFFSETS_END,
   STREAMING_OFFSETS_START,
+  STREAMING_QUERY_ID,
   STREAMING_QUERY_PROGRESS,
   STREAMING_SOURCE,
   STREAMING_TABLE,
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 4e66798d4039..bee2acec18f9 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -7211,6 +7211,18 @@
     },
     "sqlState" : "0A000"
   },
+  "STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER" : {
+    "message" : [
+      "Streaming shuffle <messageType> between writer <writerId> and reader <readerId> expected to have sequence number <expSeqNum>, but the actual sequence number is <actSeqNum>. Please verify that the messages are sent in order."
+    ],
+    "sqlState" : "XXKST"
+  },
+  "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE" : {
+    "message" : [
+      "Unexpected message type <messageType> encountered during streaming shuffle."
+    ],
+    "sqlState" : "XXKST"
+  },
   "STREAMING_STATEFUL_OPERATOR_MISSING_STATE_DIRECTORY" : {
     "message" : [
       "Cannot restart streaming query with stateful operators because the state directory is empty or missing.",
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4e56c88501ed..9c4abdf66579 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -46,6 +46,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.streaming.{MultiShuffleManager, StreamingShuffleManager}
 import org.apache.spark.storage._
 import org.apache.spark.udf.worker.UDFWorkerSpecification
 import org.apache.spark.udf.worker.core.{UDFDispatcherFactory, UDFDispatcherManager, WorkerDispatcher}
@@ -181,6 +182,7 @@ class SparkEnv (
       pythonWorkers.values.foreach(_.stop())
       udfDispatcherManager.foreach(_.close())
       mapOutputTracker.stop()
+      _streamingShuffleOutputTracker.foreach(_.stop())
       if (shuffleManager != null) {
         shuffleManager.stop()
       }
@@ -299,6 +301,48 @@ class SparkEnv (
       // Signal that the ShuffleManager has been initialized
       shuffleManagerInitLatch.countDown()
     }
+    initializeStreamingShuffleOutputTracker()
+  }
+
+  // Holds the streaming shuffle output tracker, which is only present when the configured
+  // shuffle manager requires it (i.e., StreamingShuffleManager or MultiShuffleManager).
+  @volatile private var _streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
+    None
+
+  def streamingShuffleOutputTracker: Option[StreamingShuffleOutputTracker] =
+    _streamingShuffleOutputTracker
+
+  /**
+   * Initialize the StreamingShuffleOutputTracker if the configured shuffle manager requires one
+   * and one does not already exist. This method is idempotent -- calling it multiple times is safe.
+   */
+  private def initializeStreamingShuffleOutputTracker(): Unit = {
+    if (_streamingShuffleOutputTracker.isDefined) {
+      return
+    }
+
+    val shuffleManagerName = ShuffleManager.getShuffleManagerClassName(conf)
+    if (shuffleManagerName == classOf[StreamingShuffleManager].getName
+        || shuffleManagerName == classOf[MultiShuffleManager].getName) {
+      val tracker = if (SparkContext.isDriver(executorId)) {
+        new StreamingShuffleOutputTrackerMaster(conf)
+      } else {
+        new StreamingShuffleOutputTrackerWorker(conf)
+      }
+
+      if (SparkContext.isDriver(executorId)) {
+        tracker.trackerEndpoint = rpcEnv.setupEndpoint(
+          StreamingShuffleOutputTracker.ENDPOINT_NAME,
+          new StreamingShuffleOutputTrackerMasterEndpoint(
+            rpcEnv,
+            tracker.asInstanceOf[StreamingShuffleOutputTrackerMaster],
+            conf))
+      } else {
+        tracker.trackerEndpoint = RpcUtils.makeDriverRef(
+          StreamingShuffleOutputTracker.ENDPOINT_NAME, conf, rpcEnv)
+      }
+      _streamingShuffleOutputTracker = Some(tracker)
+    }
   }
 
   private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala
new file mode 100644
index 000000000000..9e63c9375955
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/MultiShuffleManager.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import java.util.Properties
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.{ShuffleBlockResolver, ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleReadMetricsReporter, ShuffleWriteMetricsReporter, ShuffleWriter}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.streaming.MultiShuffleManager.isStreamingShuffleEnabled
+
+class MultiShuffleHandle(
+    val streamingShuffleHandle: ShuffleHandle,
+    val otherShuffleHandle: ShuffleHandle)
+  extends ShuffleHandle(streamingShuffleHandle.shuffleId)
+
+object MultiShuffleManager {
+  // Streaming shuffle is used for queries running in Real-Time Mode (concurrent stages), gated by
+  // the same per-query local property that the RTM micro-batch execution sets.
+  // TODO(SPARK-57000): once ConcurrentStageDAGScheduler is merged (apache/spark#56055), reference
+  // ConcurrentStageDAGScheduler.CONCURRENT_STAGES_ENABLED_PROPERTY here (and delegate to
+  // ConcurrentStageDAGScheduler.isConcurrentStagesEnabled) instead of hardcoding the property.
+  val STREAMING_SHUFFLE_ENABLED_PROPERTY = "streaming.concurrent.stages.enabled"
+
+  def isStreamingShuffleEnabled(properties: Properties): Boolean =
+    "true" == properties.getProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY)
+}
+
+/* This shuffle manager is used to allow real-time queries that depends on streaming shuffle
+and normal queries that depends on sort shuffle to coexist in a cluster. Right now, we only
+allows configuration of shuffle manager at cluster level, so consider using this shuffle
+manager if you want to run batch and real time queries at the same time.
+ */
+class MultiShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+  // To make sure the type of shuffle manager used for a shuffle is the same during its lifetime
+  private val shuffleIdToManager = new ConcurrentHashMap[Int, ShuffleManager]()
+  private var streamingShuffleManager: Option[StreamingShuffleManager] = None
+  private var sortShuffleManager: Option[SortShuffleManager] = None
+
+  private def shuffleManager(shuffleId: Int): ShuffleManager = {
+    shuffleIdToManager.computeIfAbsent(shuffleId, _ => {
+      val properties = SparkContext.getActive.map(_.getLocalProperties)
+        .orElse(Option(TaskContext.get()).map(_.getLocalProperties))
+        .getOrElse(throw SparkException.internalError(
+          "Cannot determine streaming shuffle routing: no active SparkContext or TaskContext"))
+      if (isStreamingShuffleEnabled(properties)) {
+        if (streamingShuffleManager.isEmpty) {
+          streamingShuffleManager = Some(new StreamingShuffleManager)
+        }
+        streamingShuffleManager.get
+      } else {
+        if (sortShuffleManager.isEmpty) {
+          sortShuffleManager = Some(new SortShuffleManager(conf))
+        }
+        sortShuffleManager.get
+      }
+    })
+  }
+
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    shuffleIdToManager.synchronized {
+      shuffleManager(shuffleId).registerShuffle(shuffleId, dependency)
+    }
+  }
+
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Long,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    shuffleIdToManager.synchronized {
+      shuffleManager(handle.shuffleId).getWriter(handle, mapId, context, metrics)
+    }
+  }
+
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    shuffleIdToManager.synchronized {
+      shuffleManager(handle.shuffleId).getReader(
+        handle,
+        startMapIndex,
+        endMapIndex,
+        startPartition,
+        endPartition,
+        context,
+        metrics)
+    }
+  }
+
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    shuffleIdToManager.synchronized {
+      val manager = shuffleIdToManager.get(shuffleId)
+      // During unregistering shuffle, which happens when shuffleDependency is garbage
+      // collected, the context might not be active anymore, in this case, we will
+      // perform no-op since there is no cached shuffle manager, meaning
+      // there are no other calls (i.e registerShuffle, getWriter, or getReader) previously
+      // invoked, thereby no state to cleanup
+      if (manager == null) {
+        return true
+      }
+
+      shuffleIdToManager.remove(shuffleId)
+      manager.unregisterShuffle(shuffleId)
+    }
+  }
+
+  override def shuffleBlockResolver: ShuffleBlockResolver = {
+    shuffleIdToManager.synchronized {
+      if (sortShuffleManager.nonEmpty) {
+        sortShuffleManager.get.shuffleBlockResolver
+      } else {
+        // don't need to support this for the streaming shuffle implementation
+        // since block manager is not used
+        throw new UnsupportedOperationException()
+      }
+    }
+  }
+
+  override def stop(): Unit = {
+    shuffleIdToManager.synchronized {
+      if (streamingShuffleManager.nonEmpty) {
+        streamingShuffleManager.get.stop()
+      }
+      if (sortShuffleManager.nonEmpty) {
+        sortShuffleManager.get.stop()
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala
new file mode 100644
index 000000000000..f56d4f0fc4f8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManager.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import org.apache.spark.{ShuffleDependency, SparkException, SparkRuntimeException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.streaming.{DataMessage, StreamingShuffleMessage, StreamingShuffleMessageType, TerminationControlMessage}
+import org.apache.spark.shuffle._
+
+class StreamingShuffleHandle[K, V, C](shuffleId: Int, dependency: ShuffleDependency[K, V, C])
+  extends BaseShuffleHandle[K, V, C](shuffleId, dependency)
+
+object StreamingShuffleManager extends Logging {
+  // Exposed for testing
+  private[spark] val QUERY_ID_PROPERTY_KEY = "sql.streaming.queryId"
+  // Since above is not applicable for batch query, we use below id to track error for batch
+  // query with streaming shuffle
+  private val QUERY_EXECUTION_ID_PROPERTY_KEY = "spark.sql.execution.id"
+
+  def getQueryId(context: TaskContext): String = {
+    Option(context.getLocalProperty(QUERY_ID_PROPERTY_KEY))
+      .orElse(Option(context.getLocalProperty(QUERY_EXECUTION_ID_PROPERTY_KEY)))
+      .getOrElse(throw SparkException.internalError(
+        "Streaming shuffle requires the query id or SQL execution id local property to be set"))
+  }
+
+  /* Called from the reader side to get the writerId associated with a message */
+  def getWriterId(message: StreamingShuffleMessage): Int = {
+    message.messageType() match {
+      case StreamingShuffleMessageType.DATA_MESSAGE_UNSAFE_ROW =>
+        message.asInstanceOf[DataMessage].shuffleWriterId
+      case StreamingShuffleMessageType.TERMINATION_CONTROL_MESSAGE =>
+        message.asInstanceOf[TerminationControlMessage].shuffleWriterId
+      case _ =>
+        // Should not reach here
+        throw streamingShuffleUnexpectedMessageType(message.messageType());
+    }
+  }
+
+  def streamingShuffleIncorrectSequenceNumber(
+      messageType: StreamingShuffleMessageType,
+      writerId: Int,
+      readerId: Int,
+      expSeqNum: Long,
+      actSeqNum: Long): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "STREAMING_SHUFFLE_INCORRECT_SEQUENCE_NUMBER",
+      messageParameters = Map(
+        "messageType" -> messageType.toString,
+        "writerId" -> writerId.toString,
+        "readerId" -> readerId.toString,
+        "expSeqNum" -> expSeqNum.toString,
+        "actSeqNum" -> actSeqNum.toString))
+  }
+
+  def streamingShuffleUnexpectedMessageType(
+      messageType: StreamingShuffleMessageType): RuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
+      messageParameters = Map("messageType" -> messageType.toString))
+  }
+}
+
+private[spark] class StreamingShuffleManager extends ShuffleManager with Logging {
+
+  logInfo(log"Using StreamingShuffleManager")
+
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    new StreamingShuffleHandle(shuffleId, dependency)
+  }
+
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Long,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
+    // Implementation is added in a follow-up commit that introduces StreamingShuffleWriter.
+    throw new UnsupportedOperationException(
+      "StreamingShuffleManager.getWriter is not yet implemented")
+  }
+
+  /**
+   * For the streaming shuffle, the startMapIndex, endMapIndex, startPartition, and endPartition
+   * arguments are not relevant.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startMapIndex: Int,
+      endMapIndex: Int,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext,
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    // Implementation is added in a follow-up commit that introduces StreamingShuffleReader.
+    throw new UnsupportedOperationException(
+      "StreamingShuffleManager.getReader is not yet implemented")
+  }
+
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    // No manager-side state to release here: the driver's StreamingShuffleOutputTracker is
+    // unregistered in BlockManagerStorageEndpoint's RemoveShuffle handler, and per-task writer
+    // and reader resources are released via task completion listeners.
+    true
+  }
+
+  override def shuffleBlockResolver: ShuffleBlockResolver = {
+    // don't need to support this for the streaming shuffle implementation
+    // since block manager is not used
+    throw new UnsupportedOperationException()
+  }
+
+  override def stop(): Unit = {}
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala b/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala
new file mode 100644
index 000000000000..fd0ac89abc79
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/streaming/TaskContextAwareLogging.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.{LogEntry, Logging, LogKeys, MessageWithContext}
+
+trait TaskContextAwareLogging extends Logging {
+
+  def context: TaskContext
+
+  private val queryId: Option[String] = Option(context)
+    .flatMap(ctx => Option(ctx.getLocalProperty("sql.streaming.queryId")).map(_.take(5)))
+    .filter(_.nonEmpty)
+
+  @volatile private var shuffleId: Option[Int] = None
+
+  def setShuffleIdForLogging(shuffleId: Int): Unit = {
+    this.shuffleId = Some(shuffleId)
+  }
+
+  private def loadTaskId: Option[String] = {
+    Option(context)
+      .flatMap(ctx => Option(ctx.partitionId()))
+      .map(_.toString)
+  }
+
+  private def loadStageId: Option[String] = {
+    Option(context)
+      .flatMap(ctx => Option(ctx.stageId()))
+      .map(_.toString)
+  }
+
+  protected def formatMessage(
+      msg: => String,
+      taskId: Option[String] = loadTaskId,
+      stageId: Option[String] = loadStageId): String = {
+    val taskIdMsg = taskId.map(tid => s"[taskId = $tid] ").getOrElse("")
+    val stageIdMsg = stageId.map(sid => s"[stageId = $sid] ").getOrElse("")
+    val shuffleIdMsg = shuffleId.map(shid => s"[shuffleId = $shid] ").getOrElse("")
+    val queryIdMsg = queryId.map(qid => s"[queryId = $qid] ").getOrElse("")
+    s"$queryIdMsg$shuffleIdMsg$stageIdMsg$taskIdMsg$msg"
+  }
+
+  override protected def logInfo(msg: => String): Unit =
+    super.logInfo(formatMessage(msg))
+
+  override protected def logInfo(entry: LogEntry): Unit =
+    super.logInfo(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+      log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry)
+
+  override protected def logWarning(msg: => String): Unit =
+    super.logWarning(formatMessage(msg))
+
+  override protected def logWarning(entry: LogEntry): Unit =
+    super.logWarning(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+      log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry)
+
+  override protected def logDebug(msg: => String): Unit =
+    super.logDebug(formatMessage(msg))
+
+  override protected def logError(msg: => String): Unit =
+    super.logError(formatMessage(msg))
+
+  override protected def logError(entry: LogEntry): Unit =
+    super.logError(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+      log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry)
+
+  override protected def logError(entry: LogEntry, throwable: Throwable): Unit =
+    super.logError(log"${MDC(LogKeys.STREAMING_QUERY_ID, queryId.getOrElse(""))} " +
+      log"${MDC(LogKeys.SHUFFLE_ID, shuffleId.getOrElse(-1))} " + entry, throwable)
+
+  override protected def logError(msg: => String, throwable: Throwable): Unit =
+    super.logError(formatMessage(msg), throwable)
+
+  protected case class LogThrottler(logFn: String => Unit, interval: Duration) {
+    private var nextLogNanos = Long.MinValue
+    private var suppressed = 0
+
+    def apply(msg: => MessageWithContext): Unit = {
+      val now = System.nanoTime()
+      if (now >= nextLogNanos) {
+        val suffix = if (suppressed > 0) s" ($suppressed suppressed)" else ""
+        logFn(msg.message + suffix)
+        nextLogNanos = now + interval.toNanos
+        suppressed = 0
+      } else {
+        suppressed += 1
+      }
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala
new file mode 100644
index 000000000000..9d9e4ce1c99a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/streaming/MultiShuffleManagerSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import java.util.Properties
+
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark._
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.shuffle.streaming.MultiShuffleManager.{isStreamingShuffleEnabled, STREAMING_SHUFFLE_ENABLED_PROPERTY}
+
+class MultiShuffleManagerSuite
+  extends SparkFunSuite
+  with LocalSparkContext
+  with Matchers {
+
+  test("isStreamingShuffleEnabled reflects the per-query property") {
+    val props = new Properties()
+    isStreamingShuffleEnabled(props) should be(false)
+
+    props.setProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "true")
+    isStreamingShuffleEnabled(props) should be(true)
+
+    props.setProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "false")
+    isStreamingShuffleEnabled(props) should be(false)
+  }
+
+  private def assertRoutesToStreaming(enabled: Boolean): Unit = {
+    withSpark(new SparkContext("local", "MultiShuffleManagerSuite", new SparkConf())) { sc =>
+      if (enabled) {
+        sc.setLocalProperty(STREAMING_SHUFFLE_ENABLED_PROPERTY, "true")
+      }
+      val rdd = sc.parallelize(1 to 4).map(x => (x, x))
+      val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2))
+      val handle = new MultiShuffleManager(sc.conf).registerShuffle(7, dep)
+      assert(handle.isInstanceOf[StreamingShuffleHandle[_, _, _]] == enabled)
+    }
+  }
+
+  test("registerShuffle routes to the streaming manager when enabled for the query") {
+    assertRoutesToStreaming(enabled = true)
+  }
+
+  test("registerShuffle routes to the sort manager when not enabled for the query") {
+    assertRoutesToStreaming(enabled = false)
+  }
+
+  test("SparkEnv initializes the streaming shuffle tracker when MultiShuffleManager is set") {
+    val conf = new SparkConf().set(SHUFFLE_MANAGER, classOf[MultiShuffleManager].getName)
+    withSpark(new SparkContext("local", "MultiShuffleManagerSuite", conf)) { _ =>
+      assert(SparkEnv.get.streamingShuffleOutputTracker.isDefined)
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala
new file mode 100644
index 000000000000..181d779e8bb5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/streaming/StreamingShuffleManagerSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.streaming
+
+import io.netty.buffer.Unpooled
+import org.mockito.Mockito.when
+import org.scalatest.matchers.should.Matchers
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark._
+import org.apache.spark.LocalSparkContext.withSpark
+import org.apache.spark.internal.config.SHUFFLE_MANAGER
+import org.apache.spark.network.shuffle.streaming.{DataMessage, TerminationAckMessage, TerminationControlMessage}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.streaming.StreamingShuffleManager.{getQueryId, getWriterId, QUERY_ID_PROPERTY_KEY}
+
+class StreamingShuffleManagerSuite
+  extends SparkFunSuite
+  with LocalSparkContext
+  with Matchers
+  with MockitoSugar {
+
+  private val SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
+
+  // ---- getWriterId ----
+
+  test("getWriterId returns the writer id for a data message") {
+    val msg = new DataMessage(7, 3, 0, Unpooled.EMPTY_BUFFER, 0L)
+    getWriterId(msg) should be(7)
+  }
+
+  test("getWriterId returns the writer id for a termination control message") {
+    getWriterId(new TerminationControlMessage(5, 2)) should be(5)
+  }
+
+  test("getWriterId throws on an unexpected message type") {
+    val e = intercept[SparkRuntimeException] {
+      getWriterId(new TerminationAckMessage(1, 1))
+    }
+    checkError(
+      e,
+      condition = "STREAMING_SHUFFLE_UNEXPECTED_MESSAGE_TYPE",
+      parameters = Map("messageType" -> "TERMINATION_ACK_MESSAGE"))
+  }
+
+  // ---- getQueryId ----
+
+  test("getQueryId returns the streaming query id when set") {
+    val context = mock[TaskContext]
+    when(context.getLocalProperty(QUERY_ID_PROPERTY_KEY)).thenReturn("query-123")
+    getQueryId(context) should be("query-123")
+  }
+
+  test("getQueryId falls back to the SQL execution id for batch queries") {
+    val context = mock[TaskContext]
+    when(context.getLocalProperty(SQL_EXECUTION_ID_KEY)).thenReturn("42")
+    getQueryId(context) should be("42")
+  }
+
+  test("getQueryId throws when no query id property is set") {
+    val context = mock[TaskContext]
+    val e = intercept[SparkException] {
+      getQueryId(context)
+    }
+    checkError(
+      e,
+      condition = "INTERNAL_ERROR",
+      parameters = Map("message" ->
+        "Streaming shuffle requires the query id or SQL execution id local property to be set"))
+  }
+
+  // ---- registerShuffle ----
+
+  test("registerShuffle returns a StreamingShuffleHandle") {
+    withSpark(new SparkContext("local", "StreamingShuffleManagerSuite", new SparkConf())) { sc =>
+      val rdd = sc.parallelize(1 to 4).map(x => (x, x))
+      val dep = new ShuffleDependency[Int, Int, Int](rdd, new HashPartitioner(2))
+      val handle = new StreamingShuffleManager().registerShuffle(0, dep)
+      assert(handle.isInstanceOf[StreamingShuffleHandle[_, _, _]])
+    }
+  }
+
+  // ---- SparkEnv tracker initialization gating ----
+
+  private def assertTrackerInitialized(shuffleManager: Option[String], expectPresent: Boolean):
+      Unit = {
+    val conf = new SparkConf()
+    shuffleManager.foreach(conf.set(SHUFFLE_MANAGER, _))
+    withSpark(new SparkContext("local", "StreamingShuffleManagerSuite", conf)) { _ =>
+      val tracker = SparkEnv.get.streamingShuffleOutputTracker
+      assert(tracker.isDefined == expectPresent)
+      // On the driver a present tracker is always the master.
+      if (expectPresent) {
+        assert(tracker.get.isInstanceOf[StreamingShuffleOutputTrackerMaster])
+      }
+    }
+  }
+
+  test("SparkEnv initializes the streaming shuffle tracker for StreamingShuffleManager") {
+    assertTrackerInitialized(Some(classOf[StreamingShuffleManager].getName), expectPresent = true)
+  }
+
+  test("SparkEnv does not initialize the tracker for a non-streaming (sort) manager") {
+    assertTrackerInitialized(Some(classOf[SortShuffleManager].getName), expectPresent = false)
+  }
+
+  test("SparkEnv does not initialize the tracker for the default manager") {
+    assertTrackerInitialized(None, expectPresent = false)
+  }
+}
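The `getQueryId` tests encode a lookup order. The function uses the streaming query id local property if it is present, falls back to the SQL execution id (`spark.sql.execution.id`) otherwise, and fails if neither is set. The sketch below reproduces only that order and makes several assumptions. The query-id key string is a hypothetical placeholder for `QUERY_ID_PROPERTY_KEY`. The lookup function mimics `TaskContext.getLocalProperty`, which returns `null` for a missing key. `IllegalStateException` stands in for Spark's `INTERNAL_ERROR` `SparkException`.

```scala
object QueryIdSketch {
  // Hypothetical key string; the real constant is QUERY_ID_PROPERTY_KEY.
  val QueryIdKey = "hypothetical.streaming.queryId"
  // Key the test suite uses for the SQL execution id.
  val SqlExecutionIdKey = "spark.sql.execution.id"

  // getLocalProperty returns null for a missing key, like TaskContext.getLocalProperty.
  def resolveQueryId(getLocalProperty: String => String): String =
    Option(getLocalProperty(QueryIdKey))                   // streaming queries
      .orElse(Option(getLocalProperty(SqlExecutionIdKey))) // batch fallback
      .getOrElse(throw new IllegalStateException(
        "Streaming shuffle requires the query id or SQL execution id local property to be set"))
}
```

Wrapping each lookup in `Option` makes the precedence explicit. The SQL execution id is consulted only when the streaming query id is absent.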

