diff --git a/src/java/org/apache/cassandra/cache/ChunkCache.java b/src/java/org/apache/cassandra/cache/ChunkCache.java index 6def74cc21b2..a856a478011a 100644 --- a/src/java/org/apache/cassandra/cache/ChunkCache.java +++ b/src/java/org/apache/cassandra/cache/ChunkCache.java @@ -661,6 +661,8 @@ public BufferHolder rebuffer(long position) // chunks. } } + assert buf.offset() <= position : "Buffer offset " + buf.offset() + " must be <= requested position " + position; + assert position == source.fileLength() || buf.buffer().limit() >= buf.offset() - position : "Buffer must be non-empty for non-EOF position " + position; return buf; } catch (Throwable t) diff --git a/src/java/org/apache/cassandra/db/commitlog/EncryptedFileSegmentInputStream.java b/src/java/org/apache/cassandra/db/commitlog/EncryptedFileSegmentInputStream.java index 171c138dce35..41e93cf7571a 100644 --- a/src/java/org/apache/cassandra/db/commitlog/EncryptedFileSegmentInputStream.java +++ b/src/java/org/apache/cassandra/db/commitlog/EncryptedFileSegmentInputStream.java @@ -23,10 +23,13 @@ import java.io.DataInput; import java.nio.ByteBuffer; +import com.google.common.base.Preconditions; + import org.apache.cassandra.io.util.DataPosition; import org.apache.cassandra.io.util.File; import org.apache.cassandra.io.util.FileDataInput; import org.apache.cassandra.io.util.FileSegmentInputStream; +import org.apache.cassandra.io.util.Rebufferer; /** * Each segment of an encrypted file may contain many encrypted chunks, and each chunk needs to be individually decrypted @@ -103,7 +106,10 @@ public long bytesPastMark(DataPosition mark) public void reBuffer() { + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); totalChunkOffset += buffer.position(); buffer = chunkProvider.nextChunk(); + if (buffer == null) + buffer = Rebufferer.EMPTY.buffer(); } } diff --git a/src/java/org/apache/cassandra/db/streaming/CompressedInputStream.java b/src/java/org/apache/cassandra/db/streaming/CompressedInputStream.java index d0b1e4c3c1cc..16dad4e8f5f1 100644 --- a/src/java/org/apache/cassandra/db/streaming/CompressedInputStream.java +++ b/src/java/org/apache/cassandra/db/streaming/CompressedInputStream.java @@ -24,6 +24,7 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.function.DoubleSupplier; +import com.google.common.base.Preconditions; import com.google.common.collect.Iterators; import com.google.common.primitives.Ints; @@ -110,8 +111,11 @@ public void position(long position) throws IOException @Override protected void reBuffer() throws IOException { - if (uncompressedChunkPosition < 0) - throw new IllegalStateException("position(long position) wasn't called first"); + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); + Preconditions.checkState(uncompressedChunkPosition >= 0, "position(long position) wasn't called first"); + + if (!compressedChunks.hasNext()) + return; // EOF, but we cannot signal it with throwing EOFException here because of the contract of reBuffer() /* * reBuffer() will only be called if a partition range spanning multiple (adjacent) compressed chunks @@ -120,6 +124,8 @@ protected void reBuffer() throws IOException */ loadNextChunk(); uncompressedChunkPosition += compressionParams.chunkLength(); + + assert buffer.hasRemaining() || !compressedChunks.hasNext() : "Buffer should have remaining bytes or be at EOF"; } /** diff --git a/src/java/org/apache/cassandra/hints/ChecksummedDataInput.java b/src/java/org/apache/cassandra/hints/ChecksummedDataInput.java index c838c4f9016a..626c1aedc1d1 100644 --- a/src/java/org/apache/cassandra/hints/ChecksummedDataInput.java +++ b/src/java/org/apache/cassandra/hints/ChecksummedDataInput.java @@ -191,10 +191,10 @@ public boolean checkCrc() throws IOException } @Override - public void readFully(byte[] b) throws IOException + public void readFully(byte[] b, int off, int len) throws IOException { - checkLimit(b.length); - super.readFully(b); + checkLimit(len); + super.readFully(b, off, len); } @Override @@ -207,7 +207,7 @@ public int read(byte[] b, int off, int len) throws IOException @Override protected void reBuffer() { - Preconditions.checkState(buffer.remaining() == 0); + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); updateCrc(); bufferOffset += buffer.limit(); @@ -219,6 +219,7 @@ protected void reBuffer() protected void readBuffer() { buffer.clear(); + //noinspection StatementWithEmptyBody while ((channel.read(buffer, bufferOffset)) == 0) {} buffer.flip(); } diff --git a/src/java/org/apache/cassandra/io/util/DataInputBuffer.java b/src/java/org/apache/cassandra/io/util/DataInputBuffer.java index 9df9861cca9c..10d81a923410 100644 --- a/src/java/org/apache/cassandra/io/util/DataInputBuffer.java +++ b/src/java/org/apache/cassandra/io/util/DataInputBuffer.java @@ -19,6 +19,8 @@ import java.nio.ByteBuffer; +import com.google.common.base.Preconditions; + /** * Input stream around a single ByteBuffer. */ @@ -58,7 +60,8 @@ public DataInputBuffer(byte[] buffer) @Override protected void reBuffer() { - //nope, we don't rebuffer, we are done! + Preconditions.checkState(!buffer.hasRemaining(), "reBuffer called with remaining bytes: %s", buffer.remaining()); + // nope, we don't rebuffer, we are done! } @Override diff --git a/src/java/org/apache/cassandra/io/util/FileInputStreamPlus.java b/src/java/org/apache/cassandra/io/util/FileInputStreamPlus.java index f4b3d7a8e67f..d935cb4521da 100644 --- a/src/java/org/apache/cassandra/io/util/FileInputStreamPlus.java +++ b/src/java/org/apache/cassandra/io/util/FileInputStreamPlus.java @@ -24,6 +24,8 @@ import java.nio.file.NoSuchFileException; import java.nio.file.Path; +import com.google.common.base.Preconditions; + public class FileInputStreamPlus extends RebufferingInputStream { final FileChannel channel; @@ -65,8 +67,10 @@ private FileInputStreamPlus(FileChannel channel, int bufferSize, Path path) @Override protected void reBuffer() throws IOException { + Preconditions.checkState(buffer.remaining() == 0, "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); buffer.clear(); - channel.read(buffer); + //noinspection StatementWithEmptyBody + while (channel.read(buffer) == 0) {} buffer.flip(); } diff --git a/src/java/org/apache/cassandra/io/util/MemoryInputStream.java b/src/java/org/apache/cassandra/io/util/MemoryInputStream.java index 3daa4c4d6852..47bc5ebbed58 100644 --- a/src/java/org/apache/cassandra/io/util/MemoryInputStream.java +++ b/src/java/org/apache/cassandra/io/util/MemoryInputStream.java @@ -23,6 +23,7 @@ import java.nio.ByteOrder; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import org.apache.cassandra.utils.memory.MemoryUtil; @@ -51,6 +52,8 @@ public MemoryInputStream(Memory mem, int bufferSize) @Override protected void reBuffer() throws IOException { + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); + if (offset - mem.peer >= mem.size()) return; diff --git a/src/java/org/apache/cassandra/io/util/NIODataInputStream.java b/src/java/org/apache/cassandra/io/util/NIODataInputStream.java index c75d44f1986f..fd29ea72debf 100644 --- a/src/java/org/apache/cassandra/io/util/NIODataInputStream.java +++ b/src/java/org/apache/cassandra/io/util/NIODataInputStream.java @@ -60,11 +60,11 @@ public NIODataInputStream(ReadableByteChannel channel, int bufferSize) @Override protected void reBuffer() throws IOException { - Preconditions.checkState(buffer.remaining() == 0); - buffer.clear(); + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); + buffer.clear(); + //noinspection StatementWithEmptyBody while ((channel.read(buffer)) == 0) {} - buffer.flip(); } diff --git a/src/java/org/apache/cassandra/io/util/RandomAccessReader.java b/src/java/org/apache/cassandra/io/util/RandomAccessReader.java index aa737d1f127c..a671e04e8a88 100644 --- a/src/java/org/apache/cassandra/io/util/RandomAccessReader.java +++ b/src/java/org/apache/cassandra/io/util/RandomAccessReader.java @@ -25,6 +25,7 @@ import java.nio.LongBuffer; import javax.annotation.concurrent.NotThreadSafe; +import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import org.apache.cassandra.io.compress.BufferType; @@ -61,8 +62,15 @@ public class RandomAccessReader extends RebufferingInputStream implements FileDa */ public void reBuffer() { + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); + if (isEOF()) + { + bufferHolder.release(); + bufferHolder = Rebufferer.emptyBufferHolderAt(length()); + buffer = bufferHolder.buffer(); return; + } reBufferAt(current()); } @@ -83,6 +91,7 @@ private void reBufferAt(long position) buffer = bufferHolder.buffer(); buffer.position(Ints.checkedCast(position - bufferHolder.offset())); buffer.order(order); + assert buffer.remaining() > 0: "Buffer must be non empty after rebuffering at " + position; } } diff --git a/src/java/org/apache/cassandra/io/util/RebufferingInputStream.java b/src/java/org/apache/cassandra/io/util/RebufferingInputStream.java index 4e17402abc23..ebe107a46642 100644 --- a/src/java/org/apache/cassandra/io/util/RebufferingInputStream.java +++ b/src/java/org/apache/cassandra/io/util/RebufferingInputStream.java @@ -56,15 +56,28 @@ protected RebufferingInputStream(ByteBuffer buffer, boolean validateByteOrder) this.buffer = buffer; } - /** - * Implementations must implement this method to refill the buffer. - * They can expect the buffer to be empty when this method is invoked. - * @throws IOException - */ + /// Refills the buffer with new data. + /// The buffer must be empty when this method is invoked. + /// The buffer must be filled with at least 1 byte of data unless EOF is reached. + /// + /// EOF is indicated by not writing any bytes to the buffer and leaving the buffer with no remaining content. + /// The implementations must not throw `EOFException` on EOF. + /// + /// Callers must not rely on the identity of the buffer object to stay the same after this call returns. + /// The buffer reference may be switched to a different buffer instance in order to provide new data, and the + /// previous buffer may be released if applicable. + /// The buffer reference may be switched to a static empty buffer in case of EOF, in order to release the current + /// exhausted buffer and to free up memory. + /// The buffer is not allowed to be set to null if the call to this method exits normally (no exception thrown). + /// + /// @throws IOException when data is expected but could not be read due to an I/O error + /// @throws IllegalStateException if the buffer hasn't been exhausted when this method is invoked protected abstract void reBuffer() throws IOException; + // This is final because it is a convenience method that simply delegates to readFully(byte[], int, int). + // Override that method instead if you want to change the behavior. @Override - public void readFully(byte[] b) throws IOException + public final void readFully(byte[] b) throws IOException { readFully(b, 0, b.length); } @@ -72,9 +85,21 @@ public void readFully(byte[] b) throws IOException @Override public void readFully(byte[] b, int off, int len) throws IOException { - int read = read(b, off, len); - if (read < len) - throw new EOFException("EOF after " + read + " bytes out of " + len); + // avoid int overflow + if (off < 0 || off > b.length || len < 0 || len > b.length - off) + throw new IndexOutOfBoundsException(); + + int copied = 0; + while (copied < len) + { + int read = readInternal(b, off, len - copied); + if (read == -1) + throw new EOFException("EOF after " + copied + " bytes out of " + len); + copied += read; + off += read; + } + + assert copied == len; } @Override @@ -84,29 +109,30 @@ public int read(byte[] b, int off, int len) throws IOException if (off < 0 || off > b.length || len < 0 || len > b.length - off) throw new IndexOutOfBoundsException(); + return readInternal(b, off, len); + } + + /// Reads up to `len` bytes into `b` at offset `off` from the current buffer. + /// Returns number of bytes read, or -1 if EOF is reached before reading any bytes. + /// If the buffer is empty, it will be refilled via `reBuffer()` once. + /// If EOF is not reached, reads at least one byte. + private int readInternal(byte[] b, int off, int len) throws IOException + { if (len == 0) return 0; - int copied = 0; - while (copied < len) + if (!buffer.hasRemaining()) { - int position = buffer.position(); - int remaining = buffer.limit() - position; - if (remaining == 0) - { - reBuffer(); - position = buffer.position(); - remaining = buffer.limit() - position; - if (remaining == 0) - return copied == 0 ? -1 : copied; - } - int toCopy = min(len - copied, remaining); - FastByteOperations.copy(buffer, position, b, off + copied, toCopy); - buffer.position(position + toCopy); - copied += toCopy; + reBuffer(); + if (!buffer.hasRemaining()) + return -1; // EOF } - return copied; + int toRead = min(len, buffer.remaining()); + assert toRead > 0 : "toRead must be > 0"; + FastByteOperations.copy(buffer, buffer.position(), b, off, toRead); + buffer.position(buffer.position() + toRead); + return toRead; } /** @@ -139,6 +165,8 @@ public void readFully(ByteBuffer dst) throws IOException buffer.position(position + toCopy); copied += toCopy; } + + assert copied == len; } @DontInline diff --git a/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java b/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java index 84fb8ac167e2..c12a08e1e68c 100644 --- a/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java +++ b/src/java/org/apache/cassandra/net/AsyncStreamingInputPlus.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.primitives.Ints; import io.netty.buffer.ByteBuf; @@ -109,6 +110,11 @@ public boolean append(ByteBuf buf) throws IllegalStateException @Override protected void reBuffer() throws EOFException, InputTimeoutException { + Preconditions.checkState(!isClosed, "Stream already closed"); + Preconditions.checkState(buffer != null, "Stream already closed: buffer is null"); + Preconditions.checkState(currentBuf != null, "Stream already closed: currentBuf is null"); + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); + if (queue.isEmpty()) channel.read(); @@ -129,9 +135,6 @@ protected void reBuffer() throws EOFException, InputTimeoutException if (null == next) throw new InputTimeoutException(); - if (next == Unpooled.EMPTY_BUFFER) // Unpooled.EMPTY_BUFFER is the indicator that the input is closed - throw new EOFException(); - currentBuf = next; buffer = next.nioBuffer(); } @@ -150,6 +153,8 @@ public void consume(Consumer consumer, long length) throws IOException { if (!buffer.hasRemaining()) reBuffer(); + if (!buffer.hasRemaining()) + throw new EOFException(); final int position = buffer.position(); final int limit = buffer.limit(); diff --git a/src/java/org/apache/cassandra/net/ChunkedInputPlus.java b/src/java/org/apache/cassandra/net/ChunkedInputPlus.java index 3aad8d96150e..ac7ec3a85158 100644 --- a/src/java/org/apache/cassandra/net/ChunkedInputPlus.java +++ b/src/java/org/apache/cassandra/net/ChunkedInputPlus.java @@ -19,9 +19,11 @@ import java.io.EOFException; +import com.google.common.base.Preconditions; import com.google.common.collect.Iterators; import com.google.common.collect.PeekingIterator; +import org.apache.cassandra.io.util.Rebufferer; import org.apache.cassandra.io.util.RebufferingInputStream; /** @@ -51,21 +53,27 @@ static ChunkedInputPlus of(Iterable buffers) { PeekingIterator iter = Iterators.peekingIterator(buffers.iterator()); if (!iter.hasNext()) - throw new IllegalArgumentException(); + throw new IllegalArgumentException("Cannot create ChunkedInputPlus from empty iterable"); return new ChunkedInputPlus(iter); } @Override protected void reBuffer() throws EOFException { - buffer = null; - iter.peek().release(); - iter.next(); + Preconditions.checkState(buffer != null, "Stream already closed"); + Preconditions.checkState(buffer.remaining() == 0, "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); - if (!iter.hasNext()) - throw new EOFException(); + buffer = Rebufferer.EMPTY.buffer(); + + // skip and release empty buffers because returning an empty buffer would mean EOF + while (iter.hasNext() && !iter.peek().hasRemaining()) + iter.next().release(); - buffer = iter.peek().get(); + if (iter.hasNext()) + { + buffer = iter.peek().get(); + assert buffer.hasRemaining() : "Next buffer should be non-empty"; + } } @Override diff --git a/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java b/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java index ceed532c5e48..f39d0b0e1df7 100644 --- a/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java +++ b/src/java/org/apache/cassandra/streaming/compress/StreamCompressionInputStream.java @@ -20,6 +20,8 @@ import java.io.IOException; +import com.google.common.base.Preconditions; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.PooledByteBufAllocator; @@ -65,6 +67,8 @@ public StreamCompressionInputStream(DataInputPlus dataInputPlus, int protocolVer @Override public void reBuffer() throws IOException { + Preconditions.checkState(!buffer.hasRemaining(), "Current buffer not exhausted, remaining bytes: %s", buffer.remaining()); + currentBuf.release(); currentBuf = deserializer.deserialize(decompressor, dataInputPlus, protocolVersion); buffer = currentBuf.nioBuffer(0, currentBuf.readableBytes()); diff --git a/test/unit/org/apache/cassandra/io/util/BufferedRandomAccessFileTest.java b/test/unit/org/apache/cassandra/io/util/BufferedRandomAccessFileTest.java index f1576ae39cf9..64fa40fe1d6e 100644 --- a/test/unit/org/apache/cassandra/io/util/BufferedRandomAccessFileTest.java +++ b/test/unit/org/apache/cassandra/io/util/BufferedRandomAccessFileTest.java @@ -32,6 +32,7 @@ import static org.apache.cassandra.Util.expectEOF; import static org.apache.cassandra.Util.expectException; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; public class BufferedRandomAccessFileTest @@ -42,6 +43,7 @@ public static void setupDD() DatabaseDescriptor.daemonInitialization(); } + @Test public void testReadAndWrite() throws Exception { @@ -346,8 +348,9 @@ public void testIsEOF() throws IOException for (int bufferSize : Arrays.asList(1, 2, 3, 5, 8, 64)) // smaller, equal, bigger buffer sizes { final byte[] target = new byte[32]; + final ByteBuffer targetBuffer = ByteBuffer.allocate(32); - // single too-large read + // single too-large read into array for (final int offset : Arrays.asList(0, 8)) { File file1 = writeTemporaryFile(new byte[16]); @@ -359,6 +362,23 @@ public void testIsEOF() throws IOException } } + // single too-large read into ByteBuffer + for (final int offset : Arrays.asList(0, 8)) + { + File file1 = writeTemporaryFile(new byte[16]); + try (FileHandle.Builder builder = new FileHandle.Builder(file1).bufferSize(bufferSize); + FileHandle fh = builder.complete(); + RandomAccessReader file = fh.createReader()) + { + expectEOF(() -> { + targetBuffer.clear(); + targetBuffer.position(offset); + file.readFully(targetBuffer); + return null; + }); + } + } + // first read is ok but eventually EOFs for (final int n : Arrays.asList(1, 2, 4, 8)) { @@ -373,6 +393,26 @@ public void testIsEOF() throws IOException }); } } + + // first read into Buffer is ok but eventually EOFs + for (final int n : Arrays.asList(1, 2, 4, 8)) + { + File file1 = writeTemporaryFile(new byte[16]); + try (FileHandle.Builder builder = new FileHandle.Builder(file1).bufferSize(bufferSize); + FileHandle fh = builder.complete(); + RandomAccessReader file = fh.createReader()) + { + expectEOF(() -> { + while (true) + { + targetBuffer.clear(); + targetBuffer.limit(n); + file.readFully(targetBuffer); + } + }); + } + } + } } diff --git a/test/unit/org/apache/cassandra/io/util/RebufferingInputStreamTest.java b/test/unit/org/apache/cassandra/io/util/RebufferingInputStreamTest.java new file mode 100644 index 000000000000..ca0c03717fc5 --- /dev/null +++ b/test/unit/org/apache/cassandra/io/util/RebufferingInputStreamTest.java @@ -0,0 +1,578 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.io.util; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.util.Arrays; +import java.util.Random; + +import com.google.common.base.Preconditions; +import com.google.common.io.LittleEndianDataInputStream; +import org.junit.Test; + +import org.apache.cassandra.utils.AssertUtil; +import org.apache.cassandra.utils.vint.VIntCoding; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class RebufferingInputStreamTest +{ + static final int STREAM_LEN = 1024 * 1024 + 15; + + @Test + public void testZeroLenRead() throws IOException + { + try (var r = new TestRebufferingInputStream(new RandomBytesInputStream(0, STREAM_LEN))) + { + byte[] buffer = new byte[8]; + assertEquals(0, r.read(buffer, 0, 0)); + + final ByteBuffer targetBuffer = ByteBuffer.allocate(4); + targetBuffer.limit(0); + r.readFully(targetBuffer); + assertEquals(0, targetBuffer.position()); + } + } + + @Test + public void testOutOfBounds() throws IOException + { + try (var r = new TestRebufferingInputStream(new RandomBytesInputStream(0, STREAM_LEN))) + { + byte[] buffer = new byte[8]; + assertThrows(IndexOutOfBoundsException.class, () -> r.read(buffer, 0, buffer.length + 1)); + assertThrows(IndexOutOfBoundsException.class, () -> r.read(buffer, buffer.length, 1)); + assertThrows(IndexOutOfBoundsException.class, () -> r.read(buffer, buffer.length + 1, 0)); + assertThrows(IndexOutOfBoundsException.class, () -> r.read(buffer, -1, 0)); + assertThrows(IndexOutOfBoundsException.class, () -> r.read(buffer, 0, -1)); + + assertThrows(IndexOutOfBoundsException.class, () -> r.readFully(buffer, 0, buffer.length + 1)); + assertThrows(IndexOutOfBoundsException.class, () -> r.readFully(buffer, buffer.length, 1)); + assertThrows(IndexOutOfBoundsException.class, () -> r.readFully(buffer, buffer.length + 1, 0)); + assertThrows(IndexOutOfBoundsException.class, () -> r.readFully(buffer, -1, 0)); + assertThrows(IndexOutOfBoundsException.class, () -> r.readFully(buffer, 0, -1)); + } + } + + @Test + public void testRead() throws IOException + { + final int SEED = 0; + var sliceSizeRng = new Random(1); + try ( + var ref = new RandomBytesInputStream(SEED, STREAM_LEN); + var test = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + while (true) + { + int toRead = sliceSizeRng.nextInt(16); + for (int i = 0; i < toRead; i++) + { + int refByte = ref.read(); + int testByte = test.read(); + assertEquals(refByte, testByte); + if (refByte == -1) + return; + } + } + } + } + + @Test + public void testReadIntoArray() throws IOException + { + final int READ_BUF_SIZE = 1024; + final int SEED = 0; + var dataRng = new Random(SEED); + var sliceSizeRng = new Random(1); + try (var stream = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + byte[] readBuffer = new byte[READ_BUF_SIZE]; + byte[] refBuffer = new byte[READ_BUF_SIZE]; + + while (true) + { + int toRead = sliceSizeRng.nextInt(READ_BUF_SIZE); + int read = stream.read(readBuffer, 0, toRead); + if (read == -1) // EOF + break; + + randomFill(dataRng, refBuffer, read); + assert Arrays.equals(refBuffer, 0, read, readBuffer, 0, read) + : "Read data does not match reference data"; + } + } + } + + @Test + public void testReadFullyIntoArray() throws IOException + { + final int READ_BUF_SIZE = 1024; + final int SEED = 0; + var dataRng = new Random(SEED); + var sliceSizeRng = new Random(1); + try (var stream = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + byte[] refBuffer = new byte[READ_BUF_SIZE]; + byte[] readBuffer = new byte[READ_BUF_SIZE]; + int totalRead = 0; + while (totalRead < STREAM_LEN) + { + int toRead = Math.min(sliceSizeRng.nextInt(READ_BUF_SIZE), STREAM_LEN - totalRead); + stream.readFully(readBuffer, 0, toRead); + totalRead += toRead; + randomFill(dataRng, refBuffer, toRead); + assert Arrays.equals(refBuffer, 0, toRead, readBuffer, 0, toRead) + : "Read data does not match reference data"; + } + } + } + + @Test + public void testReadFullyIntoBuffer() throws IOException + { + final int READ_BUF_SIZE = 11; // An arbitrary size that is not a multiple of the internal buffer size + final int SEED = 0; + var dataRng = new Random(SEED); + var sliceSizeRng = new Random(1); + try (var stream = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + ByteBuffer readBuffer = ByteBuffer.allocate(READ_BUF_SIZE); + byte[] refBuffer = new byte[READ_BUF_SIZE]; + int totalRead = 0; + while (totalRead < STREAM_LEN) + { + int toRead = Math.min(sliceSizeRng.nextInt(readBuffer.capacity()), STREAM_LEN - totalRead); + readBuffer.clear(); + readBuffer.limit(toRead); + stream.readFully(readBuffer); + totalRead += toRead; + randomFill(dataRng, refBuffer, toRead); + assert Arrays.equals(refBuffer, 0, toRead, readBuffer.array(), 0, toRead) + : "Read data does not match reference data"; + } + } + } + + @Test + public void testEOF() + { + assertThrows(EOFException.class, () -> { + try (var stream = new TestRebufferingInputStream(new RandomBytesInputStream(0, STREAM_LEN))) + { + for (int i = 0; i < STREAM_LEN / 128 + 1; i++) + { + stream.readFully(new byte[128]); + } + } + }); + assertThrows(EOFException.class, () -> { + try (var stream = new TestRebufferingInputStream(new RandomBytesInputStream(0, STREAM_LEN))) + { + ByteBuffer buffer = ByteBuffer.allocate(128); + for (int i = 0; i < STREAM_LEN / 128 + 1; i++) + { + buffer.clear(); + stream.readFully(buffer); + } + } + }); + } + + @Test + public void testReadWithSkipping() throws IOException + { + final int READ_BUF_SIZE = 1024; + final int SEED = 0; + var dataRng = new Random(SEED); + var opRng = new Random(1); + try (var stream = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + byte[] readBuffer = new byte[READ_BUF_SIZE]; + byte[] refBuffer = new byte[READ_BUF_SIZE]; + + while (true) + { + boolean shouldSkip = opRng.nextBoolean(); + int toRead = opRng.nextInt(READ_BUF_SIZE); + + if (shouldSkip) + { + int skipped = stream.skipBytes(toRead); + randomFill(dataRng, refBuffer, skipped); // read from the ref data stream to keep insync + } + else + { + int read = stream.read(readBuffer, 0, toRead); + if (read == -1) // EOF + break; + + randomFill(dataRng, refBuffer, read); + assert Arrays.equals(refBuffer, 0, read, readBuffer, 0, read) + : "Read data does not match reference data"; + } + } + } + } + + @Test + public void testBigEndianMultiByteReads() throws IOException + { + var SEED = 0; + for (Validator validator : VALIDATORS) + { + try (RandomBytesInputStream ref = new RandomBytesInputStream(SEED, STREAM_LEN); + DataInputStream refData = new DataInputStream(ref); + TestRebufferingInputStream testData = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + while (true) + { + try + { + validator.validate(refData, testData); + } + catch (EOFException e) + { + break; + } + } + } + } + + // Run with mixed types in a single sequence + Random validatorSelectRng = new Random(1); + try (RandomBytesInputStream ref = new RandomBytesInputStream(SEED, STREAM_LEN); + DataInputStream refData = new DataInputStream(ref); + TestRebufferingInputStream testData = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN))) + { + while (true) + { + try + { + Validator validator = VALIDATORS[validatorSelectRng.nextInt(VALIDATORS.length)]; + validator.validate(refData, testData); + } + catch (EOFException e) + { + break; + } + } + } + } + + @Test + public void testLittleEndianMultiByteReads() throws IOException + { + var SEED = 0; + for (Validator validator : VALIDATORS) + { + try (RandomBytesInputStream ref = new RandomBytesInputStream(SEED, STREAM_LEN); + LittleEndianDataInputStream refData = new LittleEndianDataInputStream(ref); + TestRebufferingInputStream testData = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN), + ByteOrder.LITTLE_ENDIAN)) + { + while (true) + { + try + { + validator.validate(refData, testData); + } + catch (EOFException e) + { + break; + } + } + } + } + + // Run with mixed types in a single sequence + Random validatorSelectRng = new Random(1); + try (RandomBytesInputStream ref = new RandomBytesInputStream(SEED, STREAM_LEN); + LittleEndianDataInputStream refData = new LittleEndianDataInputStream(ref); + TestRebufferingInputStream testData = new TestRebufferingInputStream(new RandomBytesInputStream(SEED, STREAM_LEN), + ByteOrder.LITTLE_ENDIAN)) + { + while (true) + { + try + { + Validator validator = VALIDATORS[validatorSelectRng.nextInt(VALIDATORS.length)]; + validator.validate(refData, testData); + } + catch (EOFException e) + { + break; + } + } + } + } + + @Test + public void testUtf8Reads() throws IOException + { + final int SEED = 0; + try (RandomUtf8InputStream ref = new RandomUtf8InputStream(SEED, STREAM_LEN); + DataInputStream refData = new DataInputStream(ref); + TestRebufferingInputStream testData = new TestRebufferingInputStream(new RandomUtf8InputStream(SEED, STREAM_LEN))) + { + while (true) + { + try + { + validate(refData::readUTF, testData::readUTF); + } + catch (EOFException e) + { + break; + } + } + } + } + + /// Helper interface to perform read validations of data of different types (short, long, double, vint etc.) + /// with the same code + interface Validator + { + /// Performs reading operation on both streams and checks if the results match. + /// Expected to throw AssertionError if the results do not match. + /// Expected to throw EOFException when the end of streams is reached. + void validate(DataInput ref, RebufferingInputStream test) throws IOException; + } + + private static final Validator[] VALIDATORS = + { + (ref, test) -> validate(ref::readByte, test::readByte), + (ref, test) -> validate(ref::readShort, test::readShort), + (ref, test) -> validate(ref::readInt, test::readInt), + (ref, test) -> validate(ref::readLong, test::readLong), + (ref, test) -> validate(ref::readFloat, test::readFloat), + (ref, test) -> validate(ref::readDouble, test::readDouble), + (ref, test) -> validate(ref::readBoolean, test::readBoolean), + (ref, test) -> validate(ref::readChar, test::readChar), + (ref, test) -> validate(ref::readUnsignedByte, test::readUnsignedByte), + (ref, test) -> validate(ref::readUnsignedShort, test::readUnsignedShort), + (ref, test) -> validate(() -> VIntCoding.readVInt(ref), test::readVInt), + (ref, test) -> validate(() -> VIntCoding.readUnsignedVInt(ref), test::readUnsignedVInt), + }; + + /// Performs reading operation on both streams and checks if the results match. + /// If the first stream hits EOF, the second stream is still read to check if it also hits EOF. + /// If both streams hit EOF, an EOFException is thrown. + /// If one stream hits EOF and another does not, then an assertion error is thrown. + private static void validate(AssertUtil.ThrowingSupplier ref, + AssertUtil.ThrowingSupplier test) throws IOException + { + EOFException eof1 = null; + EOFException eof2 = null; + T refValue = null; + T testValue = null; + + try + { + refValue = ref.get(); + } + catch (EOFException e) + { + eof1 = e; + } + catch (Throwable e) + { + throw new RuntimeException(e); + } + + try + { + testValue = test.get(); + } + catch (EOFException e) + { + eof2 = e; + } + catch (Throwable e) + { + throw new RuntimeException(e); + } + + if (eof1 != null && eof2 == null) + throw new AssertionError("Reference stream hit EOF, but test stream did not"); + if (eof1 == null && eof2 != null) + throw new AssertionError("Test stream hit EOF, but reference stream did not"); + + assertEquals(refValue, testValue); + + if (eof1 != null) + throw eof1; + } + + /// Fills the given `ByteBuffer` with random bytes from the given Random generator. + /// We are deliberately not using `rng.nextBytes()` because it does not guarantee generating + /// the same sequence of data if buffers are of random sizes (i.e. if we slice two streams of data + /// differently, we want to get the same data sequence, but `rng.nextBytes()` could generate different sequences). + public static void randomFill(Random rng, byte[] buffer, int len) + { + for (int i = 0; i < len; i++) + buffer[i] = (byte) rng.nextInt(256); + } + + /// A test RebufferingInputStream that fills the buffer with data read from another reference InputStream. + /// The buffer is filled with random sizes of data, to make sure the rebuffering logic works correctly regardless + /// of how data are sliced - e.g. the boundary between buffers may fall in the middle of a multibyte data type. + static class TestRebufferingInputStream extends RebufferingInputStream + { + static final int BUFFER_SIZE = 64; + final Random bufFillRng; + final ReadableByteChannel source; + + public TestRebufferingInputStream(InputStream source) + { + this(source, ByteOrder.BIG_ENDIAN); + } + + public TestRebufferingInputStream(InputStream source, ByteOrder order) + { + super(makeBuffer(order), order.equals(ByteOrder.BIG_ENDIAN)); + this.source = Channels.newChannel(source); + bufFillRng = new Random(0); + } + + private static ByteBuffer makeBuffer(ByteOrder order) + { + ByteBuffer buf = ByteBuffer.allocate(BUFFER_SIZE); + buf.order(order); + buf.flip(); + return buf; + } + + @Override + protected void reBuffer() throws IOException + { + Preconditions.checkState(!buffer.hasRemaining()); + int toFill = Math.max(1, bufFillRng.nextInt(buffer.capacity())); + buffer.clear(); + buffer.limit(toFill); + source.read(buffer); + buffer.flip(); + } + + @Override + public void close() throws IOException + { + source.close(); + super.close(); + } + } + + /// Data stream for producing reference input. + /// Creates a random sequence of bytes. + static class RandomBytesInputStream extends InputStream + { + final Random rng; + final int length; + + int position = 0; + + RandomBytesInputStream(int seed, int length) + { + this.rng = new Random(seed); + this.length = length; + } + + @Override + public int read() throws IOException + { + if (position >= length) + return -1; + + position++; + return rng.nextInt(256); + } + } + + /// Generates a random sequence of UTF8 strings as an InputStream. + /// The sequence may be slightly longer than the specified length because of the variable-length + /// nature of UTF8 encoding. + static class RandomUtf8InputStream extends InputStream + { + final Random rng; + final int length; + + int remaining; + + ByteBuffer buffer; + DataOutputBufferFixed dataOutput; + + RandomUtf8InputStream(int seed, int length) + { + this.rng = new Random(seed); + this.length = length; + this.remaining = length; + buffer = ByteBuffer.allocate(4096); + dataOutput = new DataOutputBufferFixed(buffer); + } + + @Override + public int read() throws IOException + { + // with UTF8 we can slightly exceed the desired stream length so remaining can get negative, + // but we don't want to cut the stream in the middle of the UTF8 string + if (remaining <= 0) + return -1; + + if (!buffer.hasRemaining()) + writeRandomUtf8String(); + + return buffer.get() & 0xFF; + } + + private void writeRandomUtf8String() throws IOException + { + dataOutput.clear(); + // The string length must be much smaller thant the buffer size, + // because UTF8 generates multiple bytes per character + var stringLen = rng.nextInt(buffer.remaining() / 8 - 4); + String str = randomString(stringLen); + dataOutput.writeUTF(str); + remaining -= dataOutput.getLength(); + buffer.flip(); + } + + private String randomString(int length) + { + StringBuilder sb = new StringBuilder(length); + for (int i = 0; i < length; i++) + { + int codePoint; + do + { + codePoint = rng.nextInt(Character.MAX_CODE_POINT + 1); + } while (Character.isSurrogate((char) codePoint)); + + sb.appendCodePoint(codePoint); + } + return sb.toString(); + } + } +} \ No newline at end of file diff --git a/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java b/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java index 3209759a49b9..e2215262d7f6 100644 --- a/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java +++ b/test/unit/org/apache/cassandra/net/ChunkedInputPlusTest.java @@ -29,8 +29,6 @@ import org.junit.Test; import org.apache.cassandra.config.DatabaseDescriptor; -import org.apache.cassandra.net.ChunkedInputPlus; -import org.apache.cassandra.net.ShareableBytes; import static org.junit.Assert.*; @@ -150,6 +148,28 @@ public void testRemainder() throws IOException } } + @Test + public void testSkipEmptyBuffers() throws IOException + { + List chunks = Lists.newArrayList( + chunk(0, 0), chunk(2, 2), chunk(0, 0), chunk(3, 3), chunk(0, 0) + ); + + try (ChunkedInputPlus input = ChunkedInputPlus.of(chunks)) + { + byte[] readBytes = new byte[5]; + input.readFully(readBytes); + assertArrayEquals(new byte[] { 2, 2, 3, 3, 3 }, readBytes); + + assertEquals(0, input.remainder()); + for (ShareableBytes chunk : chunks) + { + assertTrue(chunk.isReleased()); + assertFalse(chunk.hasRemaining()); + } + } + } + private ShareableBytes chunk(int size, int fill) { ByteBuffer buffer = ByteBuffer.allocate(size);