This is an automated email from the ASF dual-hosted git repository.

ggregory pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-io.git


The following commit(s) were added to refs/heads/master by this push:
     new 7815650ad [IO-871] IOUtils.contentEquals is incorrect when InputStream.available under-reports
7815650ad is described below

commit 7815650ada4f27dc72c8c470d626481282b13dab
Author: Gary D. Gregory <garydgreg...@gmail.com>
AuthorDate: Fri Mar 21 18:39:31 2025 -0400

    [IO-871] IOUtils.contentEquals is incorrect when InputStream.available
    under-reports
---
 src/changes/changes.xml                            |  1 +
 .../apache/commons/io/channels/FileChannels.java   | 77 ++++++++++++++++------
 .../java/org/apache/commons/io/IOUtilsTest.java    | 59 +++++++++++++++++
 .../commons/io/channels/FileChannelsTest.java      | 71 +++++++++++++++++++-
 4 files changed, 186 insertions(+), 22 deletions(-)

diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 1fbaed6e5..ab0d30b8f 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -66,6 +66,7 @@ The <action> type attribute can be add,update,fix,remove.
       <action dev="ggregory" type="fix"                due-to="Gary Gregory">Deprecate FileChannels.contentEquals(FileChannel, FileChannel, int) in favor of FileChannels.contentEquals(SeekableByteChannel, SeekableByteChannel, int).</action>
       <action dev="ggregory" type="fix"                due-to="Gary Gregory">Improve performance of IOUtils.contentEquals(InputStream, InputStream) by about 13%.</action>
       <action dev="ggregory" type="fix" issue="IO-870" due-to="Gary Gregory">PathUtils.copyFileToDirectory() across file systems #728.</action>
+      <action dev="ggregory" type="fix" issue="IO-871" due-to="Éamonn McManus, Gary Gregory">IOUtils.contentEquals is incorrect when InputStream.available under-reports.</action>
       <!-- ADD -->
       <action dev="ggregory" type="add" issue="IO-860" due-to="Nico Strecker, Gary Gregory">Add ThrottledInputStream.Builder.setMaxBytes(long, ChronoUnit).</action>
       <action dev="ggregory" type="add"                due-to="Gary Gregory">Add IOIterable.</action>
diff --git a/src/main/java/org/apache/commons/io/channels/FileChannels.java b/src/main/java/org/apache/commons/io/channels/FileChannels.java
index 4f9eb60b4..e9a63ae70 100644
--- a/src/main/java/org/apache/commons/io/channels/FileChannels.java
+++ b/src/main/java/org/apache/commons/io/channels/FileChannels.java
@@ -60,43 +60,78 @@ public static boolean contentEquals(final FileChannel channel1, final FileChanne
      */
     public static boolean contentEquals(final ReadableByteChannel channel1, final ReadableByteChannel channel2, final int bufferCapacity) throws IOException {
         // Before making any changes, please test with org.apache.commons.io.jmh.IOUtilsContentEqualsInputStreamsBenchmark
+        if (bufferCapacity <= 0) {
+            throw new IllegalArgumentException();
+        }
         // Short-circuit test
         if (Objects.equals(channel1, channel2)) {
             return true;
         }
         // Dig in and do the work
-        final ByteBuffer byteBuffer1 = ByteBuffer.allocateDirect(bufferCapacity);
-        final ByteBuffer byteBuffer2 = ByteBuffer.allocateDirect(bufferCapacity);
-        int numRead1 = 0;
-        int numRead2 = 0;
-        boolean read0On1 = false;
-        boolean read0On2 = false;
+        final ByteBuffer c1Buffer = ByteBuffer.allocateDirect(bufferCapacity);
+        final ByteBuffer c2Buffer = ByteBuffer.allocateDirect(bufferCapacity);
+        int c1ReadNum = -1;
+        int c2ReadNum = -1;
+        boolean c1Read = true;
+        boolean c2Read = true;
+        boolean equals = false;
         // If a channel is a non-blocking channel, it may return 0 bytes read for any given call.
         while (true) {
-            if (!read0On2) {
-                numRead1 = channel1.read(byteBuffer1);
-                byteBuffer1.clear();
-                read0On1 = numRead1 == 0;
+            // don't call compact() in this method to avoid copying
+            if (c1Read) {
+                c1ReadNum = channel1.read(c1Buffer);
+                c1Buffer.position(0);
+                c1Read = c1ReadNum >= 0;
             }
-            if (!read0On1) {
-                numRead2 = channel2.read(byteBuffer2);
-                byteBuffer2.clear();
-                read0On2 = numRead2 == 0;
+            if (c2Read) {
+                c2ReadNum = channel2.read(c2Buffer);
+                c2Buffer.position(0);
+                c2Read = c2ReadNum >= 0;
             }
-            if (numRead1 == IOUtils.EOF && numRead2 == IOUtils.EOF) {
-                return byteBuffer1.equals(byteBuffer2);
+            if (c1ReadNum == IOUtils.EOF && c2ReadNum == IOUtils.EOF) {
+                return equals || c1Buffer.equals(c2Buffer);
             }
-            if (numRead1 == 0 || numRead2 == 0) {
-                // 0 may be returned from a non-blocking channel.
+            if (c1ReadNum == 0 || c2ReadNum == 0) {
                 Thread.yield();
+            }
+            if (c1ReadNum == 0 && c2ReadNum == IOUtils.EOF || c2ReadNum == 0 && c1ReadNum == IOUtils.EOF) {
                 continue;
             }
-            if (numRead1 != numRead2) {
-                return false;
+            if (c1ReadNum != c2ReadNum) {
+                final int limit = Math.min(c1ReadNum, c2ReadNum);
+                if (limit == IOUtils.EOF) {
+                    return false;
+                }
+                c1Buffer.limit(limit);
+                c2Buffer.limit(limit);
+                if (!c1Buffer.equals(c2Buffer)) {
+                    return false;
+                }
+                equals = true;
+                c1Buffer.limit(bufferCapacity);
+                c2Buffer.limit(bufferCapacity);
+                c1Read = c2ReadNum > c1ReadNum;
+                c2Read = c1ReadNum > c2ReadNum;
+                if (c1Read) {
+                    c1Buffer.position(0);
+                } else {
+                    c1Buffer.position(limit);
+                    c2Buffer.limit(c1Buffer.remaining());
+                    c1ReadNum -= c2ReadNum;
+                }
+                if (c2Read) {
+                    c2Buffer.position(0);
+                } else {
+                    c2Buffer.position(limit);
+                    c1Buffer.limit(c2Buffer.remaining());
+                    c2ReadNum -= c1ReadNum;
+                }
+                continue;
             }
-            if (!byteBuffer1.equals(byteBuffer2)) {
+            if (!c1Buffer.equals(c2Buffer)) {
                 return false;
             }
+            equals = c1Read = c2Read = true;
         }
     }
 
diff --git a/src/test/java/org/apache/commons/io/IOUtilsTest.java b/src/test/java/org/apache/commons/io/IOUtilsTest.java
index 9be837d4f..ab30742c8 100644
--- a/src/test/java/org/apache/commons/io/IOUtilsTest.java
+++ b/src/test/java/org/apache/commons/io/IOUtilsTest.java
@@ -45,6 +45,7 @@
 import java.io.InputStreamReader;
 import java.io.OutputStream;
 import java.io.Reader;
+import java.io.SequenceInputStream;
 import java.io.StringReader;
 import java.io.Writer;
 import java.net.ServerSocket;
@@ -689,6 +690,64 @@ public void testContentEqualsIgnoreEOL() throws Exception {
         testSingleEOL("1235", "1234", false);
     }
 
+    @Test
+    public void testContentEqualsSequenceInputStream() throws Exception {
+        // not equals
+        // @formatter:off
+        assertFalse(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b-".getBytes()))));
+        assertFalse(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a-".getBytes()),
+                    new ByteArrayInputStream("b".getBytes()))));
+        assertFalse(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab-".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b".getBytes()))));
+        assertFalse(IOUtils.contentEquals(
+                new ByteArrayInputStream("".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b".getBytes()))));
+        assertFalse(IOUtils.contentEquals(
+                new ByteArrayInputStream("".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("b".getBytes()))));
+        assertFalse(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("".getBytes()))));
+        // equals
+        assertTrue(IOUtils.contentEquals(
+                new ByteArrayInputStream("".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("".getBytes()))));
+        assertTrue(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b".getBytes()))));
+        assertTrue(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("ab".getBytes()),
+                    new ByteArrayInputStream("".getBytes()))));
+        assertTrue(IOUtils.contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("ab".getBytes()))));
+        // @formatter:on
+    }
+
     @Test
     public void testCopy_ByteArray_OutputStream() throws Exception {
         final File destination = TestUtils.newFile(temporaryFolder, "copy8.txt");
diff --git a/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java b/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java
index 46f223d16..99a4227fc 100644
--- a/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java
+++ b/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java
@@ -22,9 +22,13 @@
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
+import java.io.ByteArrayInputStream;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.InputStream;
+import java.io.SequenceInputStream;
+import java.nio.channels.Channels;
 import java.nio.channels.FileChannel;
 import java.nio.channels.SeekableByteChannel;
 import java.nio.file.Files;
@@ -35,6 +39,8 @@
 import org.apache.commons.io.file.AbstractTempDirTest;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.StringUtils;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 import org.junitpioneer.jupiter.cartesian.CartesianTest;
 import org.junitpioneer.jupiter.cartesian.CartesianTest.Values;
 
@@ -47,7 +53,7 @@ enum FileChannelType {
         STOCK, PROXY, NON_BLOCKING, FIXED_READ_SIZE
     }
 
-    private static final int LARGE_FILE_SIZE = Integer.getInteger(FileChannelsTest.class.getSimpleName(), 100000);
+    private static final int LARGE_FILE_SIZE = Integer.getInteger(FileChannelsTest.class.getSimpleName(), 100_000);
 
     private static final int SMALL_BUFFER_SIZE = 1024;
     private static final String CONTENT = StringUtils.repeat("x", SMALL_BUFFER_SIZE);
@@ -93,6 +99,10 @@ private static FileChannel wrap(final FileChannel fc, final FileChannelType file
         }
     }
 
+    private boolean contentEquals(final InputStream in1, final InputStream in2, final int bufferCapacity) throws IOException {
+        return FileChannels.contentEquals(Channels.newChannel(in1), Channels.newChannel(in2), bufferCapacity);
+    }
+
     private void testContentEquals(final String content1, final String content2, final int bufferSize, final FileChannelType fileChannelType)
             throws IOException {
         assertTrue(FileChannels.contentEquals(null, null, bufferSize));
@@ -287,4 +297,63 @@ public void testContentEqualsSeekableByteChannel(
             Files.deleteIfExists(bigFile3);
         }
     }
+
+    @ParameterizedTest
+    @ValueSource(ints = { 1, 2, 4, 8, 16, 1024, 4096, 8192 })
+    public void testContentEqualsSequenceInputStream(final int bufferCapacity) throws Exception {
+        // not equals
+        // @formatter:off
+        assertFalse(contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b-".getBytes())), bufferCapacity));
+        assertFalse(contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a-".getBytes()),
+                    new ByteArrayInputStream("b".getBytes())), bufferCapacity));
+        assertFalse(contentEquals(
+                new ByteArrayInputStream("ab-".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b".getBytes())), bufferCapacity));
+        assertFalse(contentEquals(
+                new ByteArrayInputStream("".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b".getBytes())), bufferCapacity));
+        assertFalse(contentEquals(
+                new ByteArrayInputStream("".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("b".getBytes())), bufferCapacity));
+        assertFalse(contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("".getBytes())), bufferCapacity));
+        // equals
+        assertTrue(contentEquals(
+                new ByteArrayInputStream("".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("".getBytes())), bufferCapacity));
+        assertTrue(contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("a".getBytes()),
+                    new ByteArrayInputStream("b".getBytes())), bufferCapacity));
+        assertTrue(contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("ab".getBytes()),
+                    new ByteArrayInputStream("".getBytes())), bufferCapacity));
+        assertTrue(contentEquals(
+                new ByteArrayInputStream("ab".getBytes()),
+                new SequenceInputStream(
+                    new ByteArrayInputStream("".getBytes()),
+                    new ByteArrayInputStream("ab".getBytes())), bufferCapacity));
+        // @formatter:on
+    }
 }

Reply via email to