diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 20d840baeaf6c..8d46671f9581a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -124,6 +124,9 @@ public void setClientId(String id) { * to be returned in the same order that they were requested, assuming only a single * TransportClient is used to fetch the chunks. * + * OpenBlocks and following FetchChunk requests for a stream should be sent by the same + * TransportClient to avoid potential memory leak on server side. + * * @param streamId Identifier that refers to a stream in the remote StreamManager. This should * be agreed upon by client and server beforehand. * @param chunkIndex 0-based index of the chunk to fetch diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index f08d8b0f984cf..43c3d23b6304d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -90,7 +90,6 @@ protected void channelRead0( ManagedBuffer buf; try { streamManager.checkAuthorization(client, msg.streamChunkId.streamId); - streamManager.registerChannel(channel, msg.streamChunkId.streamId); buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java 
b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 0f6a8824d95e5..3e1f77cdc20ed 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.netty.channel.Channel; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -202,4 +203,8 @@ public long registerStream(String appId, Iterator<ManagedBuffer> buffers) { return myStreamId; } + @VisibleForTesting + public long getStreamCount() { + return streams.size(); + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa7974b87b..c59433c44c5c0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -92,6 +92,7 @@ protected void handleMessage( checkAuth(client, msg.appId); long streamId = streamManager.registerStream(client.getClientId(), new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); + streamManager.registerChannel(client.getChannel(), streamId); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/StreamStatesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/StreamStatesCleanupSuite.java new file mode 100644 index 0000000000000..4a721260e3db3 --- 
/dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/StreamStatesCleanupSuite.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; + +import io.netty.channel.Channel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; + +public class StreamStatesCleanupSuite { + + @Test + public void testStreamsAreRemovedCorrectly() { + OneForOneStreamManager streamManager = new OneForOneStreamManager(); + ExternalShuffleBlockResolver blockResolver = mock(ExternalShuffleBlockResolver.class); + TransportClient reverseClient + = new 
TransportClient(mock(Channel.class), mock(TransportResponseHandler.class)); + RpcHandler handler = new ExternalShuffleBlockHandler(streamManager, blockResolver); + + ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); + ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)) + .thenReturn(block0Marker); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)) + .thenReturn(block1Marker); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", + new String[]{"shuffle_0_0_0", "shuffle_0_0_1"}) + .toByteBuffer(); + + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + // Open blocks + handler.receive(reverseClient, openBlocks, callback); + assertEquals(1, streamManager.getStreamCount()); + + // Connection closed before any FetchChunk request received + streamManager.connectionTerminated(reverseClient.getChannel()); + assertEquals(0, streamManager.getStreamCount()); + } + +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 7076701421e2e..a72c8dc8a5485 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -61,6 +61,7 @@ class NettyBlockRpcServer( yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) logTrace(s"Registered streamId $streamId with $blocksNum buffers") + streamManager.registerChannel(client.getChannel, streamId) responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) case uploadBlock: UploadBlock => diff --git a/core/src/test/scala/org/apache/spark/network/netty/StreamsStateCleanupSuite.scala 
b/core/src/test/scala/org/apache/spark/network/netty/StreamsStateCleanupSuite.scala new file mode 100644 index 0000000000000..054cc7fff3596 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/StreamsStateCleanupSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import io.netty.channel.Channel +import org.scalatest.mockito.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.network.BlockDataManager +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient, TransportResponseHandler} +import org.apache.spark.network.server.OneForOneStreamManager +import org.apache.spark.network.shuffle.protocol.OpenBlocks +import org.apache.spark.serializer.Serializer + +class StreamsStateCleanupSuite extends SparkFunSuite with MockitoSugar { + + test("test streams are removed correctly") { + val streamManager = new OneForOneStreamManager() + val reverseClient = new TransportClient(mock[Channel], mock[TransportResponseHandler]) + val rpcHandler + = new NettyBlockRpcServer("app0", mock[Serializer], mock[BlockDataManager]) + + val openBlocks = new OpenBlocks("app0", "exec1", + Array[String]("shuffle_0_0_0", "shuffle_0_0_1")) + .toByteBuffer + val callback = mock[RpcResponseCallback] + + // Open blocks + rpcHandler.receive(reverseClient, openBlocks, callback) + assert(streamManager.getStreamCount === 1) + + // Connection closed before any FetchChunk request received + streamManager.connectionTerminated(reverseClient.getChannel) + assert(streamManager.getStreamCount === 0) + } +}
With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org