This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 52a180f  [SPARK-26674][CORE] Consolidate CompositeByteBuf when reading 
large frame
52a180f is described below

commit 52a180f25fced92cb504a7e013dcc3b7c1f86625
Author: liupengcheng <[email protected]>
AuthorDate: Mon Feb 25 16:40:46 2019 -0800

    [SPARK-26674][CORE] Consolidate CompositeByteBuf when reading large frame
    
    ## What changes were proposed in this pull request?
    
    Currently, TransportFrameDecoder does not consolidate the buffers read from 
the network, which may waste memory. In most cases a ByteBuf's writerIndex is 
far less than its capacity, so we can reduce memory consumption by 
consolidating the buffers.
    
    This PR implements that optimization.
    
    Related code:
    
https://github.com/apache/spark/blob/9a30e23211e165a44acc0dbe19693950f7a7cc73/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java#L143
    
    ## How was this patch tested?
    
    Unit tests (UT) added to TransportFrameDecoderSuite.
    
    Please review http://spark.apache.org/contributing.html before opening a 
pull request.
    
    Closes #23602 from liupc/Reduce-memory-consumption-in-TransportFrameDecoder.
    
    Lead-authored-by: liupengcheng <[email protected]>
    Co-authored-by: Liupengcheng <[email protected]>
    Signed-off-by: Marcelo Vanzin <[email protected]>
---
 .../spark/network/util/TransportFrameDecoder.java  | 77 +++++++++++++++++-----
 .../network/util/TransportFrameDecoderSuite.java   | 67 +++++++++++++++++++
 2 files changed, 127 insertions(+), 17 deletions(-)

diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 8e73ab0..1980361 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.util;
 
 import java.util.LinkedList;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.CompositeByteBuf;
@@ -48,14 +49,30 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter {
   private static final int LENGTH_SIZE = 8;
   private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
   private static final int UNKNOWN_FRAME_SIZE = -1;
+  private static final long CONSOLIDATE_THRESHOLD = 20 * 1024 * 1024;
 
   private final LinkedList<ByteBuf> buffers = new LinkedList<>();
   private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, 
LENGTH_SIZE);
+  private final long consolidateThreshold;
+
+  private CompositeByteBuf frameBuf = null;
+  private long consolidatedFrameBufSize = 0;
+  private int consolidatedNumComponents = 0;
 
   private long totalSize = 0;
   private long nextFrameSize = UNKNOWN_FRAME_SIZE;
+  private int frameRemainingBytes = UNKNOWN_FRAME_SIZE;
   private volatile Interceptor interceptor;
 
+  public TransportFrameDecoder() {
+    this(CONSOLIDATE_THRESHOLD);
+  }
+
+  @VisibleForTesting
+  TransportFrameDecoder(long consolidateThreshold) {
+    this.consolidateThreshold = consolidateThreshold;
+  }
+
   @Override
   public void channelRead(ChannelHandlerContext ctx, Object data) throws 
Exception {
     ByteBuf in = (ByteBuf) data;
@@ -123,30 +140,56 @@ public class TransportFrameDecoder extends 
ChannelInboundHandlerAdapter {
 
   private ByteBuf decodeNext() {
     long frameSize = decodeFrameSize();
-    if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) {
+    if (frameSize == UNKNOWN_FRAME_SIZE) {
       return null;
     }
 
-    // Reset size for next frame.
-    nextFrameSize = UNKNOWN_FRAME_SIZE;
-
-    Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: 
%s", frameSize);
-    Preconditions.checkArgument(frameSize > 0, "Frame length should be 
positive: %s", frameSize);
+    if (frameBuf == null) {
+      Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE,
+          "Too large frame: %s", frameSize);
+      Preconditions.checkArgument(frameSize > 0,
+          "Frame length should be positive: %s", frameSize);
+      frameRemainingBytes = (int) frameSize;
 
-    // If the first buffer holds the entire frame, return it.
-    int remaining = (int) frameSize;
-    if (buffers.getFirst().readableBytes() >= remaining) {
-      return nextBufferForFrame(remaining);
+      // If buffers is empty, then return immediately for more input data.
+      if (buffers.isEmpty()) {
+        return null;
+      }
+      // Otherwise, if the first buffer holds the entire frame, we attempt to
+      // build frame with it and return.
+      if (buffers.getFirst().readableBytes() >= frameRemainingBytes) {
+        // Reset buf and size for next frame.
+        frameBuf = null;
+        nextFrameSize = UNKNOWN_FRAME_SIZE;
+        return nextBufferForFrame(frameRemainingBytes);
+      }
+      // Other cases, create a composite buffer to manage all the buffers.
+      frameBuf = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
     }
 
-    // Otherwise, create a composite buffer.
-    CompositeByteBuf frame = 
buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
-    while (remaining > 0) {
-      ByteBuf next = nextBufferForFrame(remaining);
-      remaining -= next.readableBytes();
-      frame.addComponent(next).writerIndex(frame.writerIndex() + 
next.readableBytes());
+    while (frameRemainingBytes > 0 && !buffers.isEmpty()) {
+      ByteBuf next = nextBufferForFrame(frameRemainingBytes);
+      frameRemainingBytes -= next.readableBytes();
+      frameBuf.addComponent(true, next);
     }
-    assert remaining == 0;
+    // If the delta size of frameBuf exceeds the threshold, then we do 
consolidation
+    // to reduce memory consumption.
+    if (frameBuf.capacity() - consolidatedFrameBufSize > consolidateThreshold) 
{
+      int newNumComponents = frameBuf.numComponents() - 
consolidatedNumComponents;
+      frameBuf.consolidate(consolidatedNumComponents, newNumComponents);
+      consolidatedFrameBufSize = frameBuf.capacity();
+      consolidatedNumComponents = frameBuf.numComponents();
+    }
+    if (frameRemainingBytes > 0) {
+      return null;
+    }
+
+    // Reset buf and size for next frame.
+    ByteBuf frame = frameBuf;
+    frameBuf = null;
+    consolidatedFrameBufSize = 0;
+    consolidatedNumComponents = 0;
+    nextFrameSize = UNKNOWN_FRAME_SIZE;
     return frame;
   }
 
diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
 
b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
index 7d40387..4b67aa8 100644
--- 
a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ 
b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -27,11 +27,15 @@ import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
 import org.junit.AfterClass;
 import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
 public class TransportFrameDecoderSuite {
 
+  private static final Logger logger = 
LoggerFactory.getLogger(TransportFrameDecoderSuite.class);
   private static Random RND = new Random();
 
   @AfterClass
@@ -48,6 +52,69 @@ public class TransportFrameDecoderSuite {
   }
 
   @Test
+  public void testConsolidationPerf() throws Exception {
+    long[] testingConsolidateThresholds = new long[] {
+        ByteUnit.MiB.toBytes(1),
+        ByteUnit.MiB.toBytes(5),
+        ByteUnit.MiB.toBytes(10),
+        ByteUnit.MiB.toBytes(20),
+        ByteUnit.MiB.toBytes(30),
+        ByteUnit.MiB.toBytes(50),
+        ByteUnit.MiB.toBytes(80),
+        ByteUnit.MiB.toBytes(100),
+        ByteUnit.MiB.toBytes(300),
+        ByteUnit.MiB.toBytes(500),
+        Long.MAX_VALUE };
+    for (long threshold : testingConsolidateThresholds) {
+      TransportFrameDecoder decoder = new TransportFrameDecoder(threshold);
+      ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+      List<ByteBuf> retained = new ArrayList<>();
+      when(ctx.fireChannelRead(any())).thenAnswer(in -> {
+        ByteBuf buf = (ByteBuf) in.getArguments()[0];
+        retained.add(buf);
+        return null;
+      });
+
+      // Testing multiple messages
+      int numMessages = 3;
+      long targetBytes = ByteUnit.MiB.toBytes(300);
+      int pieceBytes = (int) ByteUnit.KiB.toBytes(32);
+      for (int i = 0; i < numMessages; i++) {
+        try {
+          long writtenBytes = 0;
+          long totalTime = 0;
+          ByteBuf buf = Unpooled.buffer(8);
+          buf.writeLong(8 + targetBytes);
+          decoder.channelRead(ctx, buf);
+          while (writtenBytes < targetBytes) {
+            buf = Unpooled.buffer(pieceBytes * 2);
+            ByteBuf writtenBuf = 
Unpooled.buffer(pieceBytes).writerIndex(pieceBytes);
+            buf.writeBytes(writtenBuf);
+            writtenBuf.release();
+            long start = System.currentTimeMillis();
+            decoder.channelRead(ctx, buf);
+            long elapsedTime = System.currentTimeMillis() - start;
+            totalTime += elapsedTime;
+            writtenBytes += pieceBytes;
+          }
+          logger.info("Writing 300MiB frame buf with consolidation of 
threshold " + threshold
+              + " took " + totalTime + " milis");
+        } finally {
+          for (ByteBuf buf : retained) {
+            release(buf);
+          }
+        }
+      }
+      long totalBytesGot = 0;
+      for (ByteBuf buf : retained) {
+        totalBytesGot += buf.capacity();
+      }
+      assertEquals(numMessages, retained.size());
+      assertEquals(targetBytes * numMessages, totalBytesGot);
+    }
+  }
+
+  @Test
   public void testInterception() throws Exception {
     int interceptedReads = 3;
     TransportFrameDecoder decoder = new TransportFrameDecoder();


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to