This is an automated email from the ASF dual-hosted git repository. ggregory pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-io.git
The following commit(s) were added to refs/heads/master by this push: new 7815650ad [IO-871] IOUtils.contentEquals is incorrect when InputStream.available under-reports 7815650ad is described below commit 7815650ada4f27dc72c8c470d626481282b13dab Author: Gary D. Gregory <garydgreg...@gmail.com> AuthorDate: Fri Mar 21 18:39:31 2025 -0400 [IO-871] IOUtils.contentEquals is incorrect when InputStream.available under-reports --- src/changes/changes.xml | 1 + .../apache/commons/io/channels/FileChannels.java | 77 ++++++++++++++++------ .../java/org/apache/commons/io/IOUtilsTest.java | 59 +++++++++++++++++ .../commons/io/channels/FileChannelsTest.java | 71 +++++++++++++++++++- 4 files changed, 186 insertions(+), 22 deletions(-) diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 1fbaed6e5..ab0d30b8f 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -66,6 +66,7 @@ The <action> type attribute can be add,update,fix,remove. <action dev="ggregory" type="fix" due-to="Gary Gregory">Deprecate FileChannels.contentEquals(FileChannel, FileChannel, int) in favor of FileChannels.contentEquals(SeekableByteChannel, SeekableByteChannel, int).</action> <action dev="ggregory" type="fix" due-to="Gary Gregory">Improve performance of IOUtils.contentEquals(InputStream, InputStream) by about 13%.</action> <action dev="ggregory" type="fix" issue="IO-870" due-to="Gary Gregory">PathUtils.copyFileToDirectory() across file systems #728.</action> + <action dev="ggregory" type="fix" issue="IO-871" due-to="Éamonn McManus, Gary Gregory">IOUtils.contentEquals is incorrect when InputStream.available under-reports.</action> <!-- ADD --> <action dev="ggregory" type="add" issue="IO-860" due-to="Nico Strecker, Gary Gregory">Add ThrottledInputStream.Builder.setMaxBytes(long, ChronoUnit).</action> <action dev="ggregory" type="add" due-to="Gary Gregory">Add IOIterable.</action> diff --git a/src/main/java/org/apache/commons/io/channels/FileChannels.java b/src/main/java/org/apache/commons/io/channels/FileChannels.java index 4f9eb60b4..e9a63ae70 100644 --- a/src/main/java/org/apache/commons/io/channels/FileChannels.java +++ b/src/main/java/org/apache/commons/io/channels/FileChannels.java @@ -60,43 +60,78 @@ public static boolean contentEquals(final FileChannel channel1, final FileChanne */ public static boolean contentEquals(final ReadableByteChannel channel1, final ReadableByteChannel channel2, final int bufferCapacity) throws IOException { // Before making any changes, please test with org.apache.commons.io.jmh.IOUtilsContentEqualsInputStreamsBenchmark + if (bufferCapacity <= 0) { + throw new IllegalArgumentException(); + } // Short-circuit test if (Objects.equals(channel1, channel2)) { return true; } // Dig in and do the work - final ByteBuffer byteBuffer1 = ByteBuffer.allocateDirect(bufferCapacity); - final ByteBuffer byteBuffer2 = ByteBuffer.allocateDirect(bufferCapacity); - int numRead1 = 0; - int numRead2 = 0; - boolean read0On1 = false; - boolean read0On2 = false; + final ByteBuffer c1Buffer = ByteBuffer.allocateDirect(bufferCapacity); + final ByteBuffer c2Buffer = ByteBuffer.allocateDirect(bufferCapacity); + int c1ReadNum = -1; + int c2ReadNum = -1; + boolean c1Read = true; + boolean c2Read = true; + boolean equals = false; // If a channel is a non-blocking channel, it may return 0 bytes read for any given call. while (true) { - if (!read0On2) { - numRead1 = channel1.read(byteBuffer1); - byteBuffer1.clear(); - read0On1 = numRead1 == 0; + // don't call compact() in this method to avoid copying + if (c1Read) { + c1ReadNum = channel1.read(c1Buffer); + c1Buffer.position(0); + c1Read = c1ReadNum >= 0; } - if (!read0On1) { - numRead2 = channel2.read(byteBuffer2); - byteBuffer2.clear(); - read0On2 = numRead2 == 0; + if (c2Read) { + c2ReadNum = channel2.read(c2Buffer); + c2Buffer.position(0); + c2Read = c2ReadNum >= 0; } - if (numRead1 == IOUtils.EOF && numRead2 == IOUtils.EOF) { - return byteBuffer1.equals(byteBuffer2); + if (c1ReadNum == IOUtils.EOF && c2ReadNum == IOUtils.EOF) { + return equals || c1Buffer.equals(c2Buffer); } - if (numRead1 == 0 || numRead2 == 0) { - // 0 may be returned from a non-blocking channel. + if (c1ReadNum == 0 || c2ReadNum == 0) { Thread.yield(); + } + if (c1ReadNum == 0 && c2ReadNum == IOUtils.EOF || c2ReadNum == 0 && c1ReadNum == IOUtils.EOF) { continue; } - if (numRead1 != numRead2) { - return false; + if (c1ReadNum != c2ReadNum) { + final int limit = Math.min(c1ReadNum, c2ReadNum); + if (limit == IOUtils.EOF) { + return false; + } + c1Buffer.limit(limit); + c2Buffer.limit(limit); + if (!c1Buffer.equals(c2Buffer)) { + return false; + } + equals = true; + c1Buffer.limit(bufferCapacity); + c2Buffer.limit(bufferCapacity); + c1Read = c2ReadNum > c1ReadNum; + c2Read = c1ReadNum > c2ReadNum; + if (c1Read) { + c1Buffer.position(0); + } else { + c1Buffer.position(limit); + c2Buffer.limit(c1Buffer.remaining()); + c1ReadNum -= c2ReadNum; + } + if (c2Read) { + c2Buffer.position(0); + } else { + c2Buffer.position(limit); + c1Buffer.limit(c2Buffer.remaining()); + c2ReadNum -= c1ReadNum; + } + continue; } - if (!byteBuffer1.equals(byteBuffer2)) { + if (!c1Buffer.equals(c2Buffer)) { return false; } + equals = c1Read = c2Read = true; } } diff --git a/src/test/java/org/apache/commons/io/IOUtilsTest.java b/src/test/java/org/apache/commons/io/IOUtilsTest.java index 9be837d4f..ab30742c8 100644 --- a/src/test/java/org/apache/commons/io/IOUtilsTest.java +++ b/src/test/java/org/apache/commons/io/IOUtilsTest.java @@ -45,6 +45,7 @@ import java.io.InputStreamReader; import java.io.OutputStream; import java.io.Reader; +import java.io.SequenceInputStream; import java.io.StringReader; import java.io.Writer; import java.net.ServerSocket; @@ -689,6 +690,64 @@ public void testContentEqualsIgnoreEOL() throws Exception { testSingleEOL("1235", "1234", false); } + @Test + public void testContentEqualsSequenceInputStream() throws Exception { + // not equals + // @formatter:off + assertFalse(IOUtils.contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b-".getBytes())))); + assertFalse(IOUtils.contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a-".getBytes()), + new ByteArrayInputStream("b".getBytes())))); + assertFalse(IOUtils.contentEquals( + new ByteArrayInputStream("ab-".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b".getBytes())))); + assertFalse(IOUtils.contentEquals( + new ByteArrayInputStream("".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b".getBytes())))); + assertFalse(IOUtils.contentEquals( + new ByteArrayInputStream("".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("b".getBytes())))); + assertFalse(IOUtils.contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("".getBytes())))); + // equals + assertTrue(IOUtils.contentEquals( + new ByteArrayInputStream("".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("".getBytes())))); + assertTrue(IOUtils.contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b".getBytes())))); + assertTrue(IOUtils.contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("ab".getBytes()), + new ByteArrayInputStream("".getBytes())))); + assertTrue(IOUtils.contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("ab".getBytes())))); + // @formatter:on + } + @Test public void testCopy_ByteArray_OutputStream() throws Exception { final File destination = TestUtils.newFile(temporaryFolder, "copy8.txt"); diff --git a/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java b/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java index 46f223d16..99a4227fc 100644 --- a/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java +++ b/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java @@ -22,9 +22,13 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.io.ByteArrayInputStream; import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.nio.channels.Channels; import java.nio.channels.FileChannel; import java.nio.channels.SeekableByteChannel; import java.nio.file.Files; @@ -35,6 +39,8 @@ import org.apache.commons.io.file.AbstractTempDirTest; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.junitpioneer.jupiter.cartesian.CartesianTest; import org.junitpioneer.jupiter.cartesian.CartesianTest.Values; @@ -47,7 +53,7 @@ enum FileChannelType { STOCK, PROXY, NON_BLOCKING, FIXED_READ_SIZE } - private static final int LARGE_FILE_SIZE = Integer.getInteger(FileChannelsTest.class.getSimpleName(), 100000); + private static final int LARGE_FILE_SIZE = Integer.getInteger(FileChannelsTest.class.getSimpleName(), 100_000); private static final int SMALL_BUFFER_SIZE = 1024; private static final String CONTENT = StringUtils.repeat("x", SMALL_BUFFER_SIZE); @@ -93,6 +99,10 @@ private static FileChannel wrap(final FileChannel fc, final FileChannelType file } } + private boolean contentEquals(final InputStream in1, final InputStream in2, final int bufferCapacity) throws IOException { + return FileChannels.contentEquals(Channels.newChannel(in1), Channels.newChannel(in2), bufferCapacity); + } + private void testContentEquals(final String content1, final String content2, final int bufferSize, final FileChannelType fileChannelType) throws IOException { assertTrue(FileChannels.contentEquals(null, null, bufferSize)); @@ -287,4 +297,63 @@ public void testContentEqualsSeekableByteChannel( Files.deleteIfExists(bigFile3); } } + + @ParameterizedTest + @ValueSource(ints = { 1, 2, 4, 8, 16, 1024, 4096, 8192 }) + public void testContentEqualsSequenceInputStream(final int bufferCapacity) throws Exception { + // not equals + // @formatter:off + assertFalse(contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b-".getBytes())), bufferCapacity)); + assertFalse(contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a-".getBytes()), + new ByteArrayInputStream("b".getBytes())), bufferCapacity)); + assertFalse(contentEquals( + new ByteArrayInputStream("ab-".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b".getBytes())), bufferCapacity)); + assertFalse(contentEquals( + new ByteArrayInputStream("".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b".getBytes())), bufferCapacity)); + assertFalse(contentEquals( + new ByteArrayInputStream("".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("b".getBytes())), bufferCapacity)); + assertFalse(contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("".getBytes())), bufferCapacity)); + // equals + assertTrue(contentEquals( + new ByteArrayInputStream("".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("".getBytes())), bufferCapacity)); + assertTrue(contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("a".getBytes()), + new ByteArrayInputStream("b".getBytes())), bufferCapacity)); + assertTrue(contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("ab".getBytes()), + new ByteArrayInputStream("".getBytes())), bufferCapacity)); + assertTrue(contentEquals( + new ByteArrayInputStream("ab".getBytes()), + new SequenceInputStream( + new ByteArrayInputStream("".getBytes()), + new ByteArrayInputStream("ab".getBytes())), bufferCapacity)); + // @formatter:on + } }