This is an automated email from the ASF dual-hosted git repository.

ggregory pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-io.git


The following commit(s) were added to refs/heads/master by this push:
     new dde1d3928 [IO-871] IOUtils.contentEquals is incorrect when InputStream.available under-reports
dde1d3928 is described below

commit dde1d39286644be95dd6bf8cfdbae5f278a4be99
Author: Gary D. Gregory <garydgreg...@gmail.com>
AuthorDate: Sun Mar 23 04:10:57 2025 -0400

    [IO-871] IOUtils.contentEquals is incorrect when InputStream.available
    under-reports
---
 .../apache/commons/io/channels/FileChannels.java   | 101 ++++++++++-----------
 .../java/org/apache/commons/io/IOUtilsTest.java    |  20 ++++
 .../commons/io/channels/FileChannelsTest.java      |   2 +
 .../io/channels/FixedReadSizeFileChannelProxy.java |  32 ++++---
 4 files changed, 89 insertions(+), 66 deletions(-)

diff --git a/src/main/java/org/apache/commons/io/channels/FileChannels.java 
b/src/main/java/org/apache/commons/io/channels/FileChannels.java
index e9a63ae70..aca2850ba 100644
--- a/src/main/java/org/apache/commons/io/channels/FileChannels.java
+++ b/src/main/java/org/apache/commons/io/channels/FileChannels.java
@@ -60,9 +60,6 @@ public static boolean contentEquals(final FileChannel 
channel1, final FileChanne
      */
     public static boolean contentEquals(final ReadableByteChannel channel1, 
final ReadableByteChannel channel2, final int bufferCapacity) throws 
IOException {
         // Before making any changes, please test with 
org.apache.commons.io.jmh.IOUtilsContentEqualsInputStreamsBenchmark
-        if (bufferCapacity <= 0) {
-            throw new IllegalArgumentException();
-        }
         // Short-circuit test
         if (Objects.equals(channel1, channel2)) {
             return true;
@@ -70,68 +67,36 @@ public static boolean contentEquals(final 
ReadableByteChannel channel1, final Re
         // Dig in and do the work
         final ByteBuffer c1Buffer = ByteBuffer.allocateDirect(bufferCapacity);
         final ByteBuffer c2Buffer = ByteBuffer.allocateDirect(bufferCapacity);
-        int c1ReadNum = -1;
-        int c2ReadNum = -1;
-        boolean c1Read = true;
-        boolean c2Read = true;
-        boolean equals = false;
+        int c1NumRead = 0;
+        int c2NumRead = 0;
+        boolean c1Read0 = false;
+        boolean c2Read0 = false;
         // If a channel is a non-blocking channel, it may return 0 bytes read 
for any given call.
         while (true) {
-            // don't call compact() in this method to avoid copying
-            if (c1Read) {
-                c1ReadNum = channel1.read(c1Buffer);
-                c1Buffer.position(0);
-                c1Read = c1ReadNum >= 0;
+            if (!c2Read0) {
+                c1NumRead = readToLimit(channel1, c1Buffer);
+                c1Buffer.clear();
+                c1Read0 = c1NumRead == 0;
             }
-            if (c2Read) {
-                c2ReadNum = channel2.read(c2Buffer);
-                c2Buffer.position(0);
-                c2Read = c2ReadNum >= 0;
+            if (!c1Read0) {
+                c2NumRead = readToLimit(channel2, c2Buffer);
+                c2Buffer.clear();
+                c2Read0 = c2NumRead == 0;
             }
-            if (c1ReadNum == IOUtils.EOF && c2ReadNum == IOUtils.EOF) {
-                return equals || c1Buffer.equals(c2Buffer);
+            if (c1NumRead == IOUtils.EOF && c2NumRead == IOUtils.EOF) {
+                return c1Buffer.equals(c2Buffer);
             }
-            if (c1ReadNum == 0 || c2ReadNum == 0) {
+            if (c1NumRead == 0 || c2NumRead == 0) {
+                // 0 may be returned from a non-blocking channel.
                 Thread.yield();
-            }
-            if (c1ReadNum == 0 && c2ReadNum == IOUtils.EOF || c2ReadNum == 0 
&& c1ReadNum == IOUtils.EOF) {
                 continue;
             }
-            if (c1ReadNum != c2ReadNum) {
-                final int limit = Math.min(c1ReadNum, c2ReadNum);
-                if (limit == IOUtils.EOF) {
-                    return false;
-                }
-                c1Buffer.limit(limit);
-                c2Buffer.limit(limit);
-                if (!c1Buffer.equals(c2Buffer)) {
-                    return false;
-                }
-                equals = true;
-                c1Buffer.limit(bufferCapacity);
-                c2Buffer.limit(bufferCapacity);
-                c1Read = c2ReadNum > c1ReadNum;
-                c2Read = c1ReadNum > c2ReadNum;
-                if (c1Read) {
-                    c1Buffer.position(0);
-                } else {
-                    c1Buffer.position(limit);
-                    c2Buffer.limit(c1Buffer.remaining());
-                    c1ReadNum -= c2ReadNum;
-                }
-                if (c2Read) {
-                    c2Buffer.position(0);
-                } else {
-                    c2Buffer.position(limit);
-                    c1Buffer.limit(c2Buffer.remaining());
-                    c2ReadNum -= c1ReadNum;
-                }
-                continue;
+            if (c1NumRead != c2NumRead) {
+                return false;
             }
             if (!c1Buffer.equals(c2Buffer)) {
                 return false;
             }
-            equals = c1Read = c2Read = true;
         }
     }
 
@@ -162,6 +127,36 @@ public static boolean contentEquals(final 
SeekableByteChannel channel1, final Se
         return size1 == 0 && size2 == 0 || contentEquals((ReadableByteChannel) 
channel1, channel2, bufferCapacity);
     }
 
+    /**
+     * Reads a sequence of bytes from a channel into the given buffer until the buffer reaches its limit or the channel has reached end-of-stream.
+     * <p>
+     * The buffer's limit is not changed.
+     * </p>
+     *
+     * @param channel The source channel.
+     * @param dst     The buffer into which bytes are to be transferred.
+     * @return The number of bytes read, possibly zero, or {@code -1} if the channel has reached end-of-stream.
+     * @throws IOException              If some other I/O error occurs.
+     * @throws IllegalArgumentException If there is no room remaining in the given buffer.
+     */
+    private static int readToLimit(final ReadableByteChannel channel, final 
ByteBuffer dst) throws IOException {
+        if (!dst.hasRemaining()) {
+            throw new IllegalArgumentException();
+        }
+        int numRead = 0;
+        int totalRead = 0;
+        while (dst.hasRemaining()) {
+            if ((totalRead += numRead = channel.read(dst)) == IOUtils.EOF) {
+                break;
+            }
+            if (numRead == 0) {
+                // 0 may be returned from a non-blocking channel.
+                Thread.yield();
+            }
+        }
+        return totalRead;
+    }
+
     private static long size(final SeekableByteChannel channel) throws 
IOException {
         return channel != null ? channel.size() : 0;
     }
diff --git a/src/test/java/org/apache/commons/io/IOUtilsTest.java 
b/src/test/java/org/apache/commons/io/IOUtilsTest.java
index ab30742c8..79b6810d2 100644
--- a/src/test/java/org/apache/commons/io/IOUtilsTest.java
+++ b/src/test/java/org/apache/commons/io/IOUtilsTest.java
@@ -61,6 +61,7 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
@@ -692,6 +693,7 @@ public void testContentEqualsIgnoreEOL() throws Exception {
 
     @Test
     public void testContentEqualsSequenceInputStream() throws Exception {
+        // https://issues.apache.org/jira/browse/IO-866
         // not equals
         // @formatter:off
         assertFalse(IOUtils.contentEquals(
@@ -746,6 +748,24 @@ public void testContentEqualsSequenceInputStream() throws 
Exception {
                     new ByteArrayInputStream("".getBytes()),
                     new ByteArrayInputStream("ab".getBytes()))));
         // @formatter:on
+        final byte[] prefixLen32 = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 0, 1, 2, 3, 
4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2 };
+        final byte[] suffixLen2 = { 1, 2 };
+        final byte[] fileContents = 
"someTexts".getBytes(StandardCharsets.UTF_8);
+        Files.write(testFile.toPath(), fileContents);
+        final byte[] expected = new byte[prefixLen32.length + 
fileContents.length + suffixLen2.length];
+        System.arraycopy(prefixLen32, 0, expected, 0, prefixLen32.length);
+        System.arraycopy(fileContents, 0, expected, prefixLen32.length, 
fileContents.length);
+        System.arraycopy(suffixLen2, 0, expected, prefixLen32.length + 
fileContents.length, suffixLen2.length);
+        // @formatter:off
+        assertTrue(IOUtils.contentEquals(
+                new ByteArrayInputStream(expected),
+                new SequenceInputStream(
+                    Collections.enumeration(
+                        Arrays.asList(
+                            new ByteArrayInputStream(prefixLen32),
+                            new FileInputStream(testFile),
+                            new ByteArrayInputStream(suffixLen2))))));
+        // @formatter:on
     }
 
     @Test
diff --git a/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java 
b/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java
index 99a4227fc..bbecd21cf 100644
--- a/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java
+++ b/src/test/java/org/apache/commons/io/channels/FileChannelsTest.java
@@ -114,12 +114,14 @@ private void testContentEquals(final String content1, 
final String content2, fin
         FileUtils.writeStringToFile(file2, content2, US_ASCII);
         // File checksums are different
         assertNotEquals(FileUtils.checksumCRC32(file1), 
FileUtils.checksumCRC32(file2));
+        // Contents are not equal
         try (FileInputStream in1 = new FileInputStream(file1);
                 FileInputStream in2 = new FileInputStream(file2);
                 FileChannel channel1 = getChannel(in1, fileChannelType, 
bufferSize);
                 FileChannel channel2 = getChannel(in2, fileChannelType, 
half(bufferSize))) {
             assertFalse(FileChannels.contentEquals(channel1, channel2, 
bufferSize));
         }
+        // Contents are not equal
         try (FileInputStream in1 = new FileInputStream(file1);
                 FileInputStream in2 = new FileInputStream(file2);
                 FileChannel channel1 = getChannel(in1, fileChannelType, 
bufferSize);
diff --git 
a/src/test/java/org/apache/commons/io/channels/FixedReadSizeFileChannelProxy.java
 
b/src/test/java/org/apache/commons/io/channels/FixedReadSizeFileChannelProxy.java
index 132626d30..d7f821922 100644
--- 
a/src/test/java/org/apache/commons/io/channels/FixedReadSizeFileChannelProxy.java
+++ 
b/src/test/java/org/apache/commons/io/channels/FixedReadSizeFileChannelProxy.java
@@ -22,7 +22,7 @@
 import java.nio.channels.FileChannel;
 
 /**
- * Always reads the same amount of bytes on each call or less.
+ * Always reads the same number of bytes (or fewer) on each call.
  */
 class FixedReadSizeFileChannelProxy extends FileChannelProxy {
 
@@ -38,26 +38,32 @@ class FixedReadSizeFileChannelProxy extends 
FileChannelProxy {
 
     @Override
     public int read(final ByteBuffer dst) throws IOException {
-        final int limit = dst.limit();
-        dst.limit(Math.min(readSize, dst.limit()));
-        final int read = super.read(dst);
-        if (read > readSize) {
+        final int saveLimit = dst.limit();
+        dst.limit(Math.min(dst.position() + readSize, dst.capacity()));
+        if (!dst.hasRemaining()) {
             throw new IllegalStateException("Programming error.");
         }
-        dst.limit(limit);
-        return read;
+        final int numRead = super.read(dst);
+        if (numRead > readSize) {
+            throw new IllegalStateException(String.format("numRead %,d > 
readSize %,d", numRead, readSize));
+        }
+        dst.limit(saveLimit);
+        return numRead;
     }
 
     @Override
     public int read(final ByteBuffer dst, final long position) throws 
IOException {
-        final int limit = dst.limit();
-        dst.limit(Math.min(readSize, dst.limit()));
-        final int read = super.read(dst, position);
-        if (read > readSize) {
+        final int saveLimit = dst.limit();
+        dst.limit(Math.min(dst.position() + readSize, dst.capacity()));
+        if (!dst.hasRemaining()) {
             throw new IllegalStateException("Programming error.");
         }
-        dst.limit(limit);
-        return read;
+        final int numRead = super.read(dst, position);
+        if (numRead > readSize) {
+            throw new IllegalStateException(String.format("numRead %,d > 
readSize %,d", numRead, readSize));
+        }
+        dst.limit(saveLimit);
+        return numRead;
     }
 
     @Override

Reply via email to