diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 20d840baeaf6c..8d46671f9581a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -124,6 +124,9 @@ public void setClientId(String id) {
    * to be returned in the same order that they were requested, assuming only a single
    * TransportClient is used to fetch the chunks.
    *
+   * OpenBlocks and following FetchChunk requests for a stream should be sent by the same
+   * TransportClient to avoid a potential memory leak on the server side.
+   *
    * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
    *                 be agreed upon by client and server beforehand.
    * @param chunkIndex 0-based index of the chunk to fetch
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
index f08d8b0f984cf..43c3d23b6304d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
@@ -90,7 +90,6 @@ protected void channelRead0(
     ManagedBuffer buf;
     try {
       streamManager.checkAuthorization(client, msg.streamChunkId.streamId);
-      streamManager.registerChannel(channel, msg.streamChunkId.streamId);
       buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);
     } catch (Exception e) {
       logger.error(String.format("Error opening block %s for request from %s",
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 0f6a8824d95e5..3e1f77cdc20ed 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -23,6 +23,7 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -202,4 +203,9 @@ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
     return myStreamId;
   }
 
+  /** Returns the number of streams currently registered with this manager (for tests). */
+  @VisibleForTesting
+  public long getStreamCount() {
+    return streams.size();
+  }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 098fa7974b87b..c59433c44c5c0 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -92,6 +92,8 @@ protected void handleMessage(
         checkAuth(client, msg.appId);
         long streamId = streamManager.registerStream(client.getClientId(),
           new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds));
+        // Register the channel so the stream is cleaned up if the connection terminates.
+        streamManager.registerChannel(client.getChannel(), streamId);
         if (logger.isTraceEnabled()) {
           logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
                        streamId,
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/StreamStatesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/StreamStatesCleanupSuite.java
new file mode 100644
index 0000000000000..4a721260e3db3
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/StreamStatesCleanupSuite.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+
+import io.netty.channel.Channel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+
+/**
+ * Checks that the stream state registered for an OpenBlocks request is removed from the
+ * stream manager when the client connection terminates before any chunk is fetched.
+ */
+public class StreamStatesCleanupSuite {
+
+  @Test
+  public void testStreamsAreRemovedCorrectly() {
+    OneForOneStreamManager streamManager = new OneForOneStreamManager();
+    ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class);
+    TransportClient reverseClient
+      = new TransportClient(mock(Channel.class), mock(TransportResponseHandler.class));
+    RpcHandler handler = new ExternalShuffleBlockHandler(streamManager, blockResolver);
+
+    // Resolve the requested blocks to real buffers, since cleanup may iterate over them.
+    ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
+    ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0))
+      .thenReturn(block0Marker);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1))
+      .thenReturn(block1Marker);
+    ByteBuffer openBlocks = new OpenBlocks("app0", "exec1",
+      new String[]{"shuffle_0_0_0", "shuffle_0_0_1"})
+      .toByteBuffer();
+
+    RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+    // Open blocks
+    handler.receive(reverseClient, openBlocks, callback);
+    assertEquals(1, streamManager.getStreamCount());
+
+    // Connection closed before any FetchChunk request received
+    streamManager.connectionTerminated(reverseClient.getChannel());
+    assertEquals(0, streamManager.getStreamCount());
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7076701421e2e..a72c8dc8a5485 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -61,6 +61,8 @@ class NettyBlockRpcServer(
           yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
         val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
         logTrace(s"Registered streamId $streamId with $blocksNum buffers")
+        // Register the channel so the stream is cleaned up if the connection terminates.
+        streamManager.registerChannel(client.getChannel, streamId)
         responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
 
       case uploadBlock: UploadBlock =>
diff --git a/core/src/test/scala/org/apache/spark/network/netty/StreamsStateCleanupSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/StreamsStateCleanupSuite.scala
new file mode 100644
index 0000000000000..054cc7fff3596
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/netty/StreamsStateCleanupSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import io.netty.channel.Channel
+import org.scalatest.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportResponseHandler}
+import org.apache.spark.network.server.OneForOneStreamManager
+import org.apache.spark.network.shuffle.protocol.OpenBlocks
+import org.apache.spark.serializer.Serializer
+
+/**
+ * Checks that the stream state NettyBlockRpcServer registers for an OpenBlocks request is
+ * removed when the client connection terminates before any chunk is fetched.
+ */
+class StreamsStateCleanupSuite extends SparkFunSuite with MockitoSugar {
+
+  test("test streams are removed correctly") {
+    val reverseClient = new TransportClient(mock[Channel], mock[TransportResponseHandler])
+    val rpcHandler
+      = new NettyBlockRpcServer("app0", mock[Serializer], mock[BlockDataManager])
+    // Inspect the handler's own stream manager: a separately created OneForOneStreamManager
+    // is never passed to the handler, so it would never see the registered stream.
+    val streamManager = rpcHandler.getStreamManager().asInstanceOf[OneForOneStreamManager]
+
+    val openBlocks = new OpenBlocks("app0", "exec1",
+      Array[String]("shuffle_0_0_0", "shuffle_0_0_1"))
+      .toByteBuffer
+    val callback = mock[RpcResponseCallback]
+
+    // Open blocks
+    rpcHandler.receive(reverseClient, openBlocks, callback)
+    assert(streamManager.getStreamCount === 1)
+
+    // Connection closed before any FetchChunk request received
+    streamManager.connectionTerminated(reverseClient.getChannel)
+    assert(streamManager.getStreamCount === 0)
+  }
+}


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to