This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d02ca49976 [SPARK-43987][SHUFFLE] Separate finalizeShuffleMerge Processing to Dedicated Thread Pools
1d02ca49976 is described below

commit 1d02ca499769794082b595617017fb951ce118ad
Author: Shu Wang <[email protected]>
AuthorDate: Fri Aug 11 19:04:32 2023 -0500

    [SPARK-43987][SHUFFLE] Separate finalizeShuffleMerge Processing to Dedicated Thread Pools
    
    ### What changes were proposed in this pull request?
    
    In this PR, we propose to separate finalizeShuffleMerge processing into a dedicated thread pool:
    1. We introduce `ShuffleTransportContext`, which extends `TransportContext`.
    2. We override `initializePipeline` to add a `FinalizedHandler` when the newly added configuration `spark.shuffle.server.finalizeShuffleMergeThreadsPercent` is positive.
    3. We override `decode` within `ShuffleMessageDecoder` so that a `FINALIZE_SHUFFLE_MERGE` type `RpcRequest` is not processed by the shared IO threads; it is instead encapsulated as an `RpcRequestInternal`.
    4. A dedicated `FinalizedHandler` is attached to the channel pipeline, which handles only messages of the `RpcRequestInternal` type (see the sketch after this list).
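
    For illustration, here is a minimal sketch of how a server can opt in, modeled on the test setup in this patch. The `startServer` helper, the port argument, and the 10 percent value are illustrative assumptions, not part of the change:

    ```java
    import java.util.Collections;

    import org.apache.spark.network.server.TransportServer;
    import org.apache.spark.network.shuffle.ExternalBlockHandler;
    import org.apache.spark.network.shuffle.ShuffleTransportContext;
    import org.apache.spark.network.util.MapConfigProvider;
    import org.apache.spark.network.util.TransportConf;

    public class FinalizePoolExample {
      // Hypothetical helper: starts a shuffle server with the dedicated finalize pool enabled.
      static TransportServer startServer(ExternalBlockHandler blockHandler, int port) {
        // Any positive percent enables the feature; 10 is just an example value.
        TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(
            Collections.singletonMap("spark.shuffle.server.finalizeShuffleMergeThreadsPercent", "10")));
        // ShuffleTransportContext swaps in ShuffleMessageDecoder and attaches FinalizedHandler,
        // so FINALIZE_SHUFFLE_MERGE requests run on the "shuffle-finalize-merge-handler"
        // event loop instead of the shared IO threads.
        ShuffleTransportContext ctx = new ShuffleTransportContext(conf, blockHandler, true);
        return ctx.createServer(port, Collections.emptyList());
      }
    }
    ```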
    
    Authors: otterc, shuwang21

    ### Why are the changes needed?
    
    In our production environment, `finalizeShuffleMerge` processing takes much longer (p90 is around 20s) than other RPC requests. This is because `finalizeShuffleMerge` invokes IO operations like truncate and file open/close.

    More importantly, processing `finalizeShuffleMerge` can block other critical lightweight messages, such as authentication requests, which can cause authentication timeouts as well as fetch failures. These timeouts and fetch failures affect the stability of Spark job executions.
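
    The fix carves out only a small, configurable slice of the IO threads for finalize processing. Below is a worked example of the sizing rule this patch adds to `TransportConf` (the 16-core host, unset `spark.shuffle.io.serverThreads`, and 10 percent setting are hypothetical values for illustration):

    ```java
    public class FinalizeThreadsMath {
      public static void main(String[] args) {
        int serverThreads = 0; // spark.shuffle.io.serverThreads left unset
        int cores = 16;        // hypothetical host size
        int percent = 10;      // spark.shuffle.server.finalizeShuffleMergeThreadsPercent
        // Mirrors TransportConf#finalizeShuffleMergeHandlerThreads below:
        int base = serverThreads > 0 ? serverThreads : 2 * cores;  // 32
        int threads = (int) Math.ceil(base * (percent / 100.0));   // ceil(3.2) = 4
        System.out.println(threads + " dedicated finalize threads");
      }
    }
    ```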
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.

    ### How was this patch tested?
    
    1. We added a few related UTs.
    2. We also tested this patch internally. Under the same setting, without this patch, the shuffle server can encounter authentication timeouts and thus fetch failures. With this patch, `finalizeShuffleMerge` is processed by the `finalizeHandler` we introduced, and authentication RPC requests are not blocked.
    3. In terms of performance, we have deployed this patch internally and saw shuffle fetch delays improve: P80 is reduced by 98.1% (120 s to 2.3 s), P90 by 83% (30 mins to ~5 mins), and P99 by 70% (44 h to 13 h). Furthermore, SASL timeout exceptions are reduced by ~40%. Spark job runtime is also improved: P50 is reduced by 35.2% (164 s to 107 s), P80 by 50.1% (15 mins to 7.7 mins), P90 by 45.5% (37 mins to 20 mins), and P99 by 10% (4.5h [...]
    
    Closes #41489 from shuwang21/SPARK-43987.
    
    Lead-authored-by: Shu Wang <[email protected]>
    Co-authored-by: Shu Wang <[email protected]>
    Co-authored-by: Chandni Singh <[email protected]>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../org/apache/spark/network/TransportContext.java |   9 +-
 .../spark/network/protocol/MessageWithHeader.java  |   3 +-
 .../network/server/TransportChannelHandler.java    |   4 +
 .../apache/spark/network/util/TransportConf.java   |  31 +++-
 .../network/shuffle/ShuffleTransportContext.java   | 195 +++++++++++++++++++++
 .../shuffle/ShuffleTransportContextSuite.java      | 140 +++++++++++++++
 .../spark/network/yarn/YarnShuffleService.java     |   3 +-
 7 files changed, 381 insertions(+), 4 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
index b885bee7032..51d074a4ddb 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.network;
 
+import io.netty.buffer.ByteBuf;
+import io.netty.handler.codec.MessageToMessageDecoder;
+
 import java.io.Closeable;
 import java.util.ArrayList;
 import java.util.List;
@@ -196,7 +199,7 @@ public class TransportContext implements Closeable {
       pipeline
         .addLast("encoder", ENCODER)
         .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
-        .addLast("decoder", DECODER)
+        .addLast("decoder", getDecoder())
         .addLast("idleStateHandler",
           new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
@@ -216,6 +219,10 @@ public class TransportContext implements Closeable {
     }
   }
 
+  protected MessageToMessageDecoder<ByteBuf> getDecoder() {
+    return DECODER;
+  }
+
   /**
    * Creates the server- and client-side handler which is used to handle both RequestMessages and
    * ResponseMessages. The channel is expected to have been successfully created, though certain
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
index dfcb1c642eb..de2c44925f6 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -35,7 +35,8 @@ import org.apache.spark.network.util.AbstractFileRegion;
  *
  * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
  */
-class MessageWithHeader extends AbstractFileRegion {
+
+public class MessageWithHeader extends AbstractFileRegion {
 
   @Nullable private final ManagedBuffer managedBuffer;
   private final ByteBuf header;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index d197032003e..f55ca2204cd 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -184,6 +184,10 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
     return responseHandler;
   }
 
+  public TransportRequestHandler getRequestHandler() {
+    return requestHandler;
+  }
+
   @Override
   public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
     transportContext.getRegisteredConnections().inc();
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 45e9994be72..2794883f3cf 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -20,7 +20,7 @@ package org.apache.spark.network.util;
 import java.util.Locale;
 import java.util.Properties;
 import java.util.concurrent.TimeUnit;
-
+import com.google.common.base.Preconditions;
 import com.google.common.primitives.Ints;
 import io.netty.util.NettyRuntime;
 
@@ -324,6 +324,35 @@ public class TransportConf {
     return conf.getInt("spark.shuffle.server.chunkFetchHandlerThreadsPercent", 0) > 0;
   }
 
+  /**
+   * Percentage of io.serverThreads used by netty to process FinalizeShuffleMerge. When the config
+   * `spark.shuffle.server.finalizeShuffleMergeThreadsPercent` is set, the shuffle server will use
+   * a separate EventLoopGroup to process FinalizeShuffleMerge messages, which are I/O intensive
+   * and could take a long time to process due to disk contention. The number of threads used for
+   * handling finalizeShuffleMerge requests is a percentage of io.serverThreads (if defined), else
+   * it is a percentage of 2 * #cores.
+   */
+  public int finalizeShuffleMergeHandlerThreads() {
+    if (!this.getModuleName().equalsIgnoreCase("shuffle")) {
+      return 0;
+    }
+    Preconditions.checkArgument(separateFinalizeShuffleMerge(),
+        "Please set spark.shuffle.server.finalizeShuffleMergeThreadsPercent to a positive value");
+    int finalizeShuffleMergeThreadsPercent =
+        Integer.parseInt(conf.get("spark.shuffle.server.finalizeShuffleMergeThreadsPercent"));
+    int threads =
+        this.serverThreads() > 0 ? this.serverThreads() : 2 * NettyRuntime.availableProcessors();
+    return (int) Math.ceil(threads * (finalizeShuffleMergeThreadsPercent / 100.0));
+  }
+
+  /**
+   * Whether to use a separate EventLoopGroup to process FinalizeShuffleMerge messages. It is
+   * decided by whether the config `spark.shuffle.server.finalizeShuffleMergeThreadsPercent` is
+   * set to a positive value.
+   */
+  public boolean separateFinalizeShuffleMerge() {
+    return conf.getInt("spark.shuffle.server.finalizeShuffleMergeThreadsPercent", 0) > 0;
+  }
+
   /**
    * Whether to use the old protocol while doing the shuffle block fetching.
    * It is only enabled while we need the compatibility in the scenario of new spark version
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
new file mode 100644
index 00000000000..39ddf2c2a7e
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleTransportContext.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.handler.codec.MessageToMessageDecoder;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.server.TransportRequestHandler;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
+
+/**
+ * Extends {@link TransportContext} to support a customized shuffle service. Specifically, we
+ * modify the Netty channel pipeline so that IO-expensive messages such as FINALIZE_SHUFFLE_MERGE
+ * are processed in separate handlers.
+ */
+public class ShuffleTransportContext extends TransportContext {
+  private static final Logger logger = LoggerFactory.getLogger(ShuffleTransportContext.class);
+  private static final ShuffleMessageDecoder SHUFFLE_DECODER =
+      new ShuffleMessageDecoder(MessageDecoder.INSTANCE);
+  private final EventLoopGroup finalizeWorkers;
+
+  public ShuffleTransportContext(
+    TransportConf conf,
+    ExternalBlockHandler rpcHandler,
+    boolean closeIdleConnections) {
+    this(conf, rpcHandler, closeIdleConnections, false);
+  }
+
+  public ShuffleTransportContext(TransportConf conf,
+      RpcHandler rpcHandler,
+      boolean closeIdleConnections,
+      boolean isClientOnly) {
+    super(conf, rpcHandler, closeIdleConnections, isClientOnly);
+
+    if ("shuffle".equalsIgnoreCase(conf.getModuleName()) && conf.separateFinalizeShuffleMerge()) {
+      finalizeWorkers = NettyUtils.createEventLoop(
+          IOMode.valueOf(conf.ioMode()),
+          conf.finalizeShuffleMergeHandlerThreads(),
+          "shuffle-finalize-merge-handler");
+      logger.info("finalize shuffle merge workers created");
+    } else {
+      finalizeWorkers = null;
+    }
+  }
+
+  @Override
+  public TransportChannelHandler initializePipeline(SocketChannel channel) {
+    TransportChannelHandler ch = super.initializePipeline(channel);
+    addHandlerToPipeline(channel, ch);
+    return ch;
+  }
+
+  @Override
+  public TransportChannelHandler initializePipeline(SocketChannel channel,
+      RpcHandler channelRpcHandler) {
+    TransportChannelHandler ch = super.initializePipeline(channel, channelRpcHandler);
+    addHandlerToPipeline(channel, ch);
+    return ch;
+  }
+
+  /**
+   * Add finalize handler to pipeline if needed. This is needed only when
+   * separateFinalizeShuffleMerge is enabled.
+   */
+  private void addHandlerToPipeline(SocketChannel channel,
+      TransportChannelHandler transportChannelHandler) {
+    if (finalizeWorkers != null) {
+      channel.pipeline().addLast(finalizeWorkers, FinalizedHandler.HANDLER_NAME,
+        new FinalizedHandler(transportChannelHandler.getRequestHandler()));
+    }
+  }
+
+  @Override
+  protected MessageToMessageDecoder<ByteBuf> getDecoder() {
+    return finalizeWorkers == null ? super.getDecoder() : SHUFFLE_DECODER;
+  }
+
+  static class ShuffleMessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+    private final MessageDecoder delegate;
+    ShuffleMessageDecoder(MessageDecoder delegate) {
+      super();
+      this.delegate = delegate;
+    }
+
+    /**
+     * Decode the message and check if it is a finalize merge request. If yes, then create an
+     * internal rpc request message and add it to the list of messages to be handled by
+     * {@link TransportChannelHandler}.
+     */
+    @Override
+    protected void decode(ChannelHandlerContext channelHandlerContext,
+        ByteBuf byteBuf,
+        List<Object> list) throws Exception {
+      delegate.decode(channelHandlerContext, byteBuf, list);
+      Object msg = list.get(list.size() - 1);
+      if (msg instanceof RpcRequest) {
+        RpcRequest req = (RpcRequest) msg;
+        ByteBuffer buffer = req.body().nioByteBuffer();
+        byte type = Unpooled.wrappedBuffer(buffer).readByte();
+        if (type == BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE.id()) {
+          list.remove(list.size() - 1);
+          RpcRequestInternal rpcRequestInternal =
+            new RpcRequestInternal(BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE, req);
+          logger.trace("Created internal rpc request msg with rpcId {} for finalize merge req",
+            req.requestId);
+          list.add(rpcRequestInternal);
+        }
+      }
+    }
+  }
+
+  /**
+   * Internal message to handle rpc requests that are not accepted by
+   * {@link TransportChannelHandler}, as this message doesn't extend {@link Message}. It will be
+   * accepted by {@link FinalizedHandler} instead, which is configured to execute in a separate
+   * EventLoopGroup.
+   */
+  static class RpcRequestInternal {
+    public final BlockTransferMessage.Type messageType;
+    public final RpcRequest rpcRequest;
+
+    RpcRequestInternal(BlockTransferMessage.Type messageType,
+        RpcRequest rpcRequest) {
+      this.messageType = messageType;
+      this.rpcRequest = rpcRequest;
+    }
+  }
+
+  static class FinalizedHandler extends SimpleChannelInboundHandler<RpcRequestInternal> {
+    private static final Logger logger = LoggerFactory.getLogger(FinalizedHandler.class);
+    public static final String HANDLER_NAME = "finalizeHandler";
+    private final TransportRequestHandler transportRequestHandler;
+
+    @Override
+    public boolean acceptInboundMessage(Object msg) throws Exception {
+      if (msg instanceof RpcRequestInternal) {
+        RpcRequestInternal rpcRequestInternal = (RpcRequestInternal) msg;
+        return rpcRequestInternal.messageType == BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE;
+      }
+      return false;
+    }
+
+    FinalizedHandler(TransportRequestHandler transportRequestHandler) {
+      this.transportRequestHandler = transportRequestHandler;
+    }
+
+    @Override
+    protected void channelRead0(ChannelHandlerContext channelHandlerContext,
+        RpcRequestInternal req) throws Exception {
+      if (logger.isTraceEnabled()) {
+        logger.trace("Finalize shuffle req from {} for rpc request {}",
+                getRemoteAddress(channelHandlerContext.channel()), req.rpcRequest.requestId);
+      }
+      this.transportRequestHandler.handle(req.rpcRequest);
+    }
+  }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
new file mode 100644
index 00000000000..1c8c5b33bd9
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleTransportContextSuite.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.socket.nio.NioSocketChannel;
+import io.netty.channel.socket.SocketChannel;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.protocol.Message;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.MessageWithHeader;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.MapConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class ShuffleTransportContextSuite {
+
+  private ExternalBlockHandler blockHandler;
+
+  @Before
+  public void before() throws IOException {
+    blockHandler = mock(ExternalBlockHandler.class);
+  }
+
+  ShuffleTransportContext createShuffleTransportContext(boolean separateFinalizeThread)
+      throws IOException {
+    Map<String, String> configs = new HashMap<>();
+    configs.put("spark.shuffle.server.finalizeShuffleMergeThreadsPercent",
+        separateFinalizeThread ? "1" : "0");
+    TransportConf transportConf = new TransportConf("shuffle",
+        new MapConfigProvider(configs));
+    return new ShuffleTransportContext(transportConf, blockHandler, true);
+  }
+
+  private ByteBuf getDecodableMessageBuf(Message req) throws Exception {
+    List<Object> out = Lists.newArrayList();
+    ChannelHandlerContext context = mock(ChannelHandlerContext.class);
+    when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT);
+    MessageEncoder.INSTANCE.encode(context, req, out);
+    MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0);
+    ByteArrayWritableChannel writableChannel =
+      new ByteArrayWritableChannel((int) msgWithHeader.count());
+    while (msgWithHeader.transfered() < msgWithHeader.count()) {
+      msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered());
+    }
+    ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData());
+    messageBuf.readLong(); // frame length
+    return messageBuf;
+  }
+
+  @Test
+  public void testInitializePipeline() throws IOException {
+    // SPARK-43987: test that the FinalizedHandler is added to the pipeline only when configured
+    for (boolean enabled : new boolean[]{true, false}) {
+      ShuffleTransportContext ctx = createShuffleTransportContext(enabled);
+      SocketChannel channel = new NioSocketChannel();
+      RpcHandler rpcHandler = mock(RpcHandler.class);
+      ctx.initializePipeline(channel, rpcHandler);
+      String handlerName = ShuffleTransportContext.FinalizedHandler.HANDLER_NAME;
+      if (enabled) {
+        Assert.assertNotNull(channel.pipeline().get(handlerName));
+      } else {
+        Assert.assertNull(channel.pipeline().get(handlerName));
+      }
+    }
+  }
+
+  @Test
+  public void testDecodeOfFinalizeShuffleMessage() throws Exception {
+    // SPARK-43987: test FinalizeShuffleMerge message is decoded correctly
+    FinalizeShuffleMerge finalizeRequest = new FinalizeShuffleMerge("app0", 1, 2, 3);
+    RpcRequest rpcRequest = new RpcRequest(1, new NioManagedBuffer(finalizeRequest.toByteBuffer()));
+    ByteBuf messageBuf = getDecodableMessageBuf(rpcRequest);
+    ShuffleTransportContext shuffleTransportContext = createShuffleTransportContext(true);
+    ShuffleTransportContext.ShuffleMessageDecoder decoder =
+        (ShuffleTransportContext.ShuffleMessageDecoder) shuffleTransportContext.getDecoder();
+    List<Object> out = Lists.newArrayList();
+    decoder.decode(mock(ChannelHandlerContext.class), messageBuf, out);
+
+    Assert.assertEquals(1, out.size());
+    Assert.assertTrue(out.get(0) instanceof ShuffleTransportContext.RpcRequestInternal);
+    Assert.assertEquals(BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE,
+        ((ShuffleTransportContext.RpcRequestInternal) out.get(0)).messageType);
+  }
+
+  @Test
+  public void testDecodeOfAnyOtherRpcMessage() throws Exception {
+    // SPARK-43987: test any other RPC message is decoded correctly
+    OpenBlocks openBlocks = new OpenBlocks("app0", "1", new String[]{"block1", "block2"});
+    RpcRequest rpcRequest = new RpcRequest(1, new NioManagedBuffer(openBlocks.toByteBuffer()));
+    ByteBuf messageBuf = getDecodableMessageBuf(rpcRequest);
+    ShuffleTransportContext shuffleTransportContext = createShuffleTransportContext(true);
+    ShuffleTransportContext.ShuffleMessageDecoder decoder =
+        (ShuffleTransportContext.ShuffleMessageDecoder) shuffleTransportContext.getDecoder();
+    List<Object> out = Lists.newArrayList();
+    decoder.decode(mock(ChannelHandlerContext.class), messageBuf, out);
+
+    Assert.assertEquals(1, out.size());
+    Assert.assertTrue(out.get(0) instanceof RpcRequest);
+    Assert.assertEquals(rpcRequest.requestId, ((RpcRequest) out.get(0)).requestId);
+  }
+}
diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index b34ebf6e29b..b9b9568aa47 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -46,6 +46,7 @@ import org.apache.hadoop.yarn.server.api.*;
 import org.apache.spark.network.shuffle.Constants;
 import org.apache.spark.network.shuffle.MergedShuffleFileManager;
 import org.apache.spark.network.shuffle.NoOpMergedShuffleFileManager;
+import org.apache.spark.network.shuffle.ShuffleTransportContext;
 import org.apache.spark.network.shuffledb.DB;
 import org.apache.spark.network.shuffledb.DBBackend;
 import org.apache.spark.network.shuffledb.DBIterator;
@@ -300,7 +301,7 @@ public class YarnShuffleService extends AuxiliaryService {
 
       int port = _conf.getInt(
         SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT);
-      transportContext = new TransportContext(transportConf, blockHandler, true);
+      transportContext = new ShuffleTransportContext(transportConf, blockHandler, true);
       shuffleServer = transportContext.createServer(port, bootstraps);
       // the port should normally be fixed, but for tests its useful to find an open port
       port = shuffleServer.getPort();


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
