Skip to content

Commit

Permalink
Introduces IndexInput#updateReadAdvice to change the readadvice while
Browse files Browse the repository at this point in the history
reading IndexInput

The change is needed to be able to reduce the force merge time.
Lucene99FlatVectorsReader is opened with IOContext.RANDOM, this optimizes
searches with madvise as RANDOM. For merges we need sequential access and
ability to preload pages to be able to shorten the merge time.

The change updates the ReadAdvice.SEQUENTIAL before the merge starts and reverts it
to ReadAdvice.RANDOM at the end of the merge for
Lucene99FlatVectorsReader.
  • Loading branch information
shatejas committed Nov 9, 2024
1 parent 12ca477 commit b6a8619
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,11 @@ public abstract void search(
public KnnVectorsReader getMergeInstance() {
return this;
}

/**
* Optional: reset or close merge resources used in the reader
*
* <p>The default implementation is empty
*/
public void finishMerge() throws IOException {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,18 @@ public final void merge(MergeState mergeState) throws IOException {
}
}
}
finishMerge(mergeState);
finish();
}

private void finishMerge(MergeState mergeState) throws IOException {
for (KnnVectorsReader reader : mergeState.knnVectorsReaders) {
if (reader != null) {
reader.finishMerge();
}
}
}

/** Tracks state of one sub-reader that we are merging */
private static class FloatVectorValuesSub extends DocIDMerger.Sub {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,15 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData);
}

@Override
public FlatVectorsReader getMergeInstance() {
try {
return new MergeLucene99FlatVectorsReader(this);
} catch (IOException exception) {
throw new RuntimeException(exception);
}
}

private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
Expand Down Expand Up @@ -327,4 +336,62 @@ static FieldEntry create(IndexInput input, FieldInfo info) throws IOException {
info);
}
}

private static final class MergeLucene99FlatVectorsReader extends FlatVectorsReader {

private final Lucene99FlatVectorsReader delegate;

MergeLucene99FlatVectorsReader(final Lucene99FlatVectorsReader flatVectorsReader)
throws IOException {
super(flatVectorsReader.vectorScorer);
this.delegate = flatVectorsReader;
this.delegate.vectorData.updateReadAdvice(ReadAdvice.SEQUENTIAL);
}

@Override
public RandomVectorScorer getRandomVectorScorer(String field, float[] target)
throws IOException {
return delegate.getRandomVectorScorer(field, target);
}

@Override
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
throws IOException {
return delegate.getRandomVectorScorer(field, target);
}

@Override
public void checkIntegrity() throws IOException {
delegate.checkIntegrity();
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return delegate.getFloatVectorValues(field);
}

@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return delegate.getByteVectorValues(field);
}

@Override
public void close() throws IOException {
delegate.close();
}

@Override
public void finishMerge() throws IOException {
// This makes sure that the access pattern hint is reverted back since HNSW implementation
// needs it
delegate.vectorData.updateReadAdvice(ReadAdvice.RANDOM);
delegate.finishMerge();
}
;

@Override
public long ramBytesUsed() {
return delegate.ramBytesUsed();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader

private final FlatVectorsReader flatVectorsReader;
private final FieldInfos fieldInfos;
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IntObjectHashMap<FieldEntry> fields;
private final IndexInput vectorIndex;

public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
throws IOException {
this.fields = new IntObjectHashMap<>();
this.flatVectorsReader = flatVectorsReader;
boolean success = false;
this.fieldInfos = state.fieldInfos;
Expand Down Expand Up @@ -113,6 +114,25 @@ public Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatV
}
}

private Lucene99HnswVectorsReader(
Lucene99HnswVectorsReader reader, KnnVectorsReader flatVectorsReader) {
assert flatVectorsReader instanceof FlatVectorsReader;
this.flatVectorsReader = (FlatVectorsReader) flatVectorsReader;
this.fieldInfos = reader.fieldInfos;
this.fields = reader.fields;
this.vectorIndex = reader.vectorIndex;
}

@Override
public KnnVectorsReader getMergeInstance() {
return new Lucene99HnswVectorsReader(this, this.flatVectorsReader.getMergeInstance());
}

@Override
public void finishMerge() throws IOException {
flatVectorsReader.finishMerge();
}

private static IndexInput openDataInput(
SegmentReadState state,
int versionMeta,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,27 @@ public FieldsReader(final SegmentReadState readState) throws IOException {
}
}

private FieldsReader(final FieldsReader fieldsReader) {
this.fieldInfos = fieldsReader.fieldInfos;
for (FieldInfo fi : this.fieldInfos) {
if (fi.hasVectorValues() && fieldsReader.fields.containsKey(fi.number)) {
this.fields.put(fi.number, fieldsReader.fields.get(fi.number).getMergeInstance());
}
}
}

@Override
public KnnVectorsReader getMergeInstance() {
return new FieldsReader(this);
}

@Override
public void finishMerge() throws IOException {
for (ObjectCursor<KnnVectorsReader> knnVectorReader : fields.values()) {
knnVectorReader.value.finishMerge();
}
}

/**
* Return the underlying VectorReader for the given field
*
Expand Down
8 changes: 8 additions & 0 deletions lucene/core/src/java/org/apache/lucene/store/IndexInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,12 @@ public String toString() {
* @param length the number of bytes to prefetch
*/
public void prefetch(long offset, long length) throws IOException {}

/**
* Optional method: Give a hint to this input about the change in read access pattern. IndexInput
* implementations may take advantage of this hint to optimize reads from storage.
*
* <p>The default implementation is a no-op.
*/
public void updateReadAdvice(ReadAdvice readAdvice) throws IOException {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,19 @@ public void prefetch(long offset, long length) throws IOException {
});
}

public void updateReadAdvice(ReadAdvice readAdvice) throws IOException {
if (NATIVE_ACCESS.isEmpty()) {
return;
}
final NativeAccess nativeAccess = NATIVE_ACCESS.get();

long offset = 0;
for (MemorySegment seg : segments) {
advise(offset, seg.byteSize(), segment -> nativeAccess.madvise(segment, readAdvice));
offset += seg.byteSize();
}
}

void advise(long offset, long length, IOConsumer<MemorySegment> advice) throws IOException {
if (NATIVE_ACCESS.isEmpty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.tests.mockfile.ExtrasFS;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
Expand Down Expand Up @@ -1554,6 +1555,42 @@ public void testPrefetchOnSlice() throws IOException {
doTestPrefetch(TestUtil.nextInt(random(), 1, 1024));
}

public void testUpdateReadAdvice() throws IOException {
try (Directory dir = getDirectory(createTempDir("testUpdateReadAdvice"))) {
final int totalLength = TestUtil.nextInt(random(), 16384, 65536);
byte[] arr = new byte[totalLength];
random().nextBytes(arr);
try (IndexOutput out = dir.createOutput("temp.bin", IOContext.DEFAULT)) {
out.writeBytes(arr, arr.length);
}

try (IndexInput orig = dir.openInput("temp.bin", IOContext.DEFAULT)) {
IndexInput in = random().nextBoolean() ? orig.clone() : orig;
// Read advice updated at start
orig.updateReadAdvice(randomReadAdvice());
for (int i = 0; i < totalLength; i++) {
int offset = TestUtil.nextInt(random(), 0, (int) in.length() - 1);
in.seek(offset);
assertEquals(arr[offset], in.readByte());
}

// Updating readAdvice in the middle
for (int i = 0; i < 10_000; ++i) {
int offset = TestUtil.nextInt(random(), 0, (int) in.length() - 1);
in.seek(offset);
assertEquals(arr[offset], in.readByte());
if (random().nextBoolean()) {
orig.updateReadAdvice(randomReadAdvice());
}
}
}
}
}

private ReadAdvice randomReadAdvice() {
return ReadAdvice.values()[TestUtil.nextInt(random(), 0, ReadAdvice.values().length - 1)];
}

private void doTestPrefetch(int startOffset) throws IOException {
try (Directory dir = getDirectory(createTempDir())) {
final int totalLength = startOffset + TestUtil.nextInt(random(), 16384, 65536);
Expand Down

0 comments on commit b6a8619

Please sign in to comment.