diff --git a/src/main/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannel.java b/src/main/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannel.java index 2f998c0e26e..df709321e0d 100644 --- a/src/main/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannel.java +++ b/src/main/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannel.java @@ -47,7 +47,7 @@ public class SeekableInMemoryByteChannel implements SeekableByteChannel { private static final int NAIVE_RESIZE_LIMIT = Integer.MAX_VALUE >> 1; private byte[] data; private final AtomicBoolean closed = new AtomicBoolean(); - private int position; + private long position; private int size; /** @@ -113,25 +113,28 @@ public long position() throws ClosedChannelException { @Override public SeekableByteChannel position(final long newPosition) throws IOException { ensureOpen(); - if (newPosition < 0L || newPosition > Integer.MAX_VALUE) { - throw new IllegalArgumentException(String.format("Position must be in range [0..%,d]: %,d", Integer.MAX_VALUE, newPosition)); + if (newPosition < 0L) { + throw new IllegalArgumentException(String.format("New position is negative: %,d", newPosition)); } - position = (int) newPosition; + position = newPosition; return this; } @Override public int read(final ByteBuffer buf) throws IOException { ensureOpen(); + if (position > Integer.MAX_VALUE) { + return -1; + } int wanted = buf.remaining(); - final int possible = size - position; + final int possible = size - (int) position; if (possible <= 0) { return -1; } if (wanted > possible) { wanted = possible; } - buf.put(data, position, wanted); + buf.put(data, (int) position, wanted); position += wanted; return wanted; } @@ -160,14 +163,14 @@ public long size() throws ClosedChannelException { @Override public SeekableByteChannel truncate(final long newSize) throws ClosedChannelException { ensureOpen(); - if (newSize < 0L || newSize > Integer.MAX_VALUE) { - throw new IllegalArgumentException("Size must be range [0.." + Integer.MAX_VALUE + "]"); + if (newSize < 0L) { + throw new IllegalArgumentException(String.format("New size is negative: %,d", newSize)); } if (size > newSize) { size = (int) newSize; } if (position > newSize) { - position = (int) newSize; + position = newSize; } return this; } @@ -175,21 +178,27 @@ public SeekableByteChannel truncate(final long newSize) throws ClosedChannelExce @Override public int write(final ByteBuffer b) throws IOException { ensureOpen(); + if (position > Integer.MAX_VALUE) { + throw new IOException("position > Integer.MAX_VALUE"); + } int wanted = b.remaining(); - final int possibleWithoutResize = size - position; + // intPos <= Integer.MAX_VALUE + int intPos = (int) position; + final int possibleWithoutResize = size - intPos; if (wanted > possibleWithoutResize) { - final int newSize = position + wanted; + final int newSize = intPos + wanted; if (newSize < 0) { // overflow resize(Integer.MAX_VALUE); - wanted = Integer.MAX_VALUE - position; + wanted = Integer.MAX_VALUE - intPos; } else { resize(newSize); } } - b.get(data, position, wanted); - position += wanted; - if (size < position) { - size = position; + b.get(data, intPos, wanted); + // intPos + wanted is at most (Integer.MAX_VALUE - intPos) + intPos + position = intPos += wanted; + if (size < intPos) { + size = intPos; } return wanted; } diff --git a/src/test/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannelTest.java b/src/test/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannelTest.java index 711ca6d2aac..d1f96943840 100644 --- a/src/test/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannelTest.java +++ b/src/test/java/org/apache/commons/compress/utils/SeekableInMemoryByteChannelTest.java @@ -152,16 +152,33 @@ void testShouldThrowExceptionOnWritingToClosedChannel() { } @Test - void testShouldThrowExceptionWhenSettingIncorrectPosition() { + void testShouldThrowWhenSettingIncorrectPosition() throws IOException { try (SeekableInMemoryByteChannel c = new SeekableInMemoryByteChannel()) { - assertThrows(IllegalArgumentException.class, () -> c.position(Integer.MAX_VALUE + 1L)); + final ByteBuffer buffer = ByteBuffer.allocate(1); + c.position(c.size() + 1); + assertEquals(c.size() + 1, c.position()); + assertEquals(-1, c.read(buffer)); + c.position(Integer.MAX_VALUE + 1L); + assertEquals(Integer.MAX_VALUE + 1L, c.position()); + assertEquals(-1, c.read(buffer)); + assertThrows(IOException.class, () -> c.write(buffer)); + assertThrows(IllegalArgumentException.class, () -> c.position(-1)); + assertThrows(IllegalArgumentException.class, () -> c.position(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> c.position(Long.MIN_VALUE)); } } @Test - void testShouldThrowExceptionWhenTruncatingToIncorrectSize() { + void testShouldThrowWhenTruncatingToIncorrectSize() throws IOException { try (SeekableInMemoryByteChannel c = new SeekableInMemoryByteChannel()) { - assertThrows(IllegalArgumentException.class, () -> c.truncate(Integer.MAX_VALUE + 1L)); + final ByteBuffer buffer = ByteBuffer.allocate(1); + c.truncate(c.size() + 1); + assertEquals(1, c.read(buffer)); + c.truncate(Integer.MAX_VALUE + 1L); + assertEquals(0, c.read(buffer)); + assertThrows(IllegalArgumentException.class, () -> c.truncate(-1)); + assertThrows(IllegalArgumentException.class, () -> c.truncate(Integer.MIN_VALUE)); + assertThrows(IllegalArgumentException.class, () -> c.truncate(Long.MIN_VALUE)); } }