Repository: spark
Updated Branches:
  refs/heads/master 38cf8f2a5 -> 67e085ef6


[SPARK-16420] Ensure compression streams are closed.

## What changes were proposed in this pull request?

This patch uses the try/finally pattern to ensure streams are closed after use.
`UnsafeShuffleWriter` wasn't closing compression streams, causing them to leak
resources until they were garbage collected. This was causing a problem with
codecs that use off-heap memory.

## How was this patch tested?

Existing tests are sufficient; this change should not alter behavior.

Author: Ryan Blue <[email protected]>

Closes #14093 from rdblue/SPARK-16420-unsafe-shuffle-writer-leak.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/67e085ef
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/67e085ef
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/67e085ef

Branch: refs/heads/master
Commit: 67e085ef6dd62774095f3187844c091db1a6a72c
Parents: 38cf8f2
Author: Ryan Blue <[email protected]>
Authored: Fri Jul 8 12:37:26 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Fri Jul 8 12:37:26 2016 -0700

----------------------------------------------------------------------
 .../spark/network/util/LimitedInputStream.java  | 23 ++++++++++++++++++++
 .../spark/shuffle/sort/UnsafeShuffleWriter.java | 17 ++++++++++-----
 .../spark/broadcast/TorrentBroadcast.scala      | 13 ++++++++---
 .../serializer/GenericAvroSerializer.scala      | 15 ++++++++++---
 4 files changed, 57 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/67e085ef/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
----------------------------------------------------------------------
diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
 
b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
index 922c37a..e79eef0 100644
--- 
a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
+++ 
b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
@@ -48,11 +48,27 @@ import com.google.common.base.Preconditions;
  * use this functionality in both a Guava 11 environment and a Guava &gt;14 
environment.
  */
 public final class LimitedInputStream extends FilterInputStream {
+  private final boolean closeWrappedStream;
   private long left;
   private long mark = -1;
 
   public LimitedInputStream(InputStream in, long limit) {
+    this(in, limit, true);
+  }
+
+  /**
+   * Create a LimitedInputStream that will read {@code limit} bytes from 
{@code in}.
+   * <p>
+   * If {@code closeWrappedStream} is true, this will close {@code in} when it 
is closed.
+   * Otherwise, the stream is left open for reading its remaining content.
+   *
+   * @param in a {@link InputStream} to read from
+   * @param limit the number of bytes to read
+   * @param closeWrappedStream whether to close {@code in} when {@link #close} 
is called
+     */
+  public LimitedInputStream(InputStream in, long limit, boolean 
closeWrappedStream) {
     super(in);
+    this.closeWrappedStream = closeWrappedStream;
     Preconditions.checkNotNull(in);
     Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
     left = limit;
@@ -102,4 +118,11 @@ public final class LimitedInputStream extends 
FilterInputStream {
     left -= skipped;
     return skipped;
   }
+
+  @Override
+  public void close() throws IOException {
+    if (closeWrappedStream) {
+      super.close();
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/67e085ef/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 05fa04c..08fb887 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -349,12 +349,19 @@ public class UnsafeShuffleWriter<K, V> extends 
ShuffleWriter<K, V> {
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = 
spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream =
-              new LimitedInputStream(spillInputStreams[i], 
partitionLengthInSpill);
-            if (compressionCodec != null) {
-              partitionInputStream = 
compressionCodec.compressedInputStream(partitionInputStream);
+            InputStream partitionInputStream = null;
+            boolean innerThrewException = true;
+            try {
+              partitionInputStream =
+                  new LimitedInputStream(spillInputStreams[i], 
partitionLengthInSpill, false);
+              if (compressionCodec != null) {
+                partitionInputStream = 
compressionCodec.compressedInputStream(partitionInputStream);
+              }
+              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+              innerThrewException = false;
+            } finally {
+              Closeables.close(partitionInputStream, innerThrewException);
             }
-            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
           }
         }
         mergedFileOutputStream.flush();

http://git-wip-us.apache.org/repos/asf/spark/blob/67e085ef/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala 
b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 632b0ae..e8d6d58 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -232,7 +232,11 @@ private object TorrentBroadcast extends Logging {
     val out = compressionCodec.map(c => 
c.compressedOutputStream(cbbos)).getOrElse(cbbos)
     val ser = serializer.newInstance()
     val serOut = ser.serializeStream(out)
-    serOut.writeObject[T](obj).close()
+    Utils.tryWithSafeFinally {
+      serOut.writeObject[T](obj)
+    } {
+      serOut.close()
+    }
     cbbos.toChunkedByteBuffer.getChunks()
   }
 
@@ -246,8 +250,11 @@ private object TorrentBroadcast extends Logging {
     val in: InputStream = compressionCodec.map(c => 
c.compressedInputStream(is)).getOrElse(is)
     val ser = serializer.newInstance()
     val serIn = ser.deserializeStream(in)
-    val obj = serIn.readObject[T]()
-    serIn.close()
+    val obj = Utils.tryWithSafeFinally {
+      serIn.readObject[T]()
+    } {
+      serIn.close()
+    }
     obj
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/67e085ef/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index d17a789..f0ed41f 100644
--- 
a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ 
b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.Utils
 
 /**
  * Custom serializer used for generic Avro records. If the user registers the 
schemas
@@ -72,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: 
Map[Long, String])
   def compress(schema: Schema): Array[Byte] = 
compressCache.getOrElseUpdate(schema, {
     val bos = new ByteArrayOutputStream()
     val out = codec.compressedOutputStream(bos)
-    out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
-    out.close()
+    Utils.tryWithSafeFinally {
+      out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
+    } {
+      out.close()
+    }
     bos.toByteArray
   })
 
@@ -86,7 +90,12 @@ private[serializer] class GenericAvroSerializer(schemas: 
Map[Long, String])
       schemaBytes.array(),
       schemaBytes.arrayOffset() + schemaBytes.position(),
       schemaBytes.remaining())
-    val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
+    val in = codec.compressedInputStream(bis)
+    val bytes = Utils.tryWithSafeFinally {
+      IOUtils.toByteArray(in)
+    } {
+      in.close()
+    }
     new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
   })
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to