Skip to content

Commit b26ff62

Browse files
committed
Avoid double free in index creation/destruction; add tests for specific parts
1 parent ebecaf9 commit b26ff62

File tree

4 files changed

+246
-192
lines changed

4 files changed

+246
-192
lines changed

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 71 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
2929
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*;
3030
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
31+
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
3132
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceBuild;
3233
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceDeserialize;
3334
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndexCreate;
@@ -37,14 +38,10 @@
3738
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSerialize;
3839
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMAlloc;
3940
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMFree;
40-
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsResources_t;
4141
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
4242
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
4343
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
44-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
4544

46-
import java.io.FileInputStream;
47-
import java.io.FileOutputStream;
4845
import java.io.InputStream;
4946
import java.io.OutputStream;
5047
import java.lang.foreign.Arena;
@@ -64,7 +61,6 @@
6461
import com.nvidia.cuvs.Dataset;
6562
import com.nvidia.cuvs.SearchResults;
6663
import com.nvidia.cuvs.internal.common.Util;
67-
import com.nvidia.cuvs.internal.panama.cuvsBruteForceIndex;
6864
import com.nvidia.cuvs.internal.panama.cuvsFilter;
6965

7066
/**
@@ -130,8 +126,13 @@ private void checkNotDestroyed() {
130126
public void destroyIndex() throws Throwable {
131127
checkNotDestroyed();
132128
try {
133-
int returnValue = cuvsBruteForceIndexDestroy(bruteForceIndexReference.getMemorySegment());
129+
int returnValue = cuvsBruteForceIndexDestroy(bruteForceIndexReference.indexPtr);
134130
checkCuVSError(returnValue, "cuvsBruteForceIndexDestroy");
131+
132+
if (bruteForceIndexReference.datasetBytes > 0) {
133+
returnValue = cuvsRMMFree(resources.getHandle(), bruteForceIndexReference.datasetPtr, bruteForceIndexReference.datasetBytes);
134+
checkCuVSError(returnValue, "cuvsRMMFree");
135+
}
135136
} finally {
136137
destroyed = true;
137138
}
@@ -145,7 +146,7 @@ public void destroyIndex() throws Throwable {
145146
* @return an instance of {@link IndexReference} that holds the pointer to the
146147
* index
147148
*/
148-
private IndexReference build() throws Throwable {
149+
private IndexReference build() {
149150
try (var localArena = Arena.ofConfined()) {
150151
long rows = dataset != null? dataset.size(): vectors.length;
151152
long cols = dataset != null? dataset.dimensions(): (rows > 0 ? vectors[0].length : 0);
@@ -155,16 +156,13 @@ private IndexReference build() throws Throwable {
155156
Util.buildMemorySegment(resources.getArena(), vectors);
156157

157158
long cuvsResources = resources.getHandle();
158-
MemorySegment stream = arena.allocate(cudaStream_t);
159-
var returnValue = cuvsStreamGet(cuvsResources, stream);
160-
checkCuVSError(returnValue, "cuvsStreamGet");
161159

162160
omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());
163161

164-
MemorySegment datasetMemorySegment = arena.allocate(C_POINTER);
162+
MemorySegment datasetMemorySegment = localArena.allocate(C_POINTER);
165163

166164
long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols;
167-
returnValue = cuvsRMMAlloc(cuvsResources, datasetMemorySegment, datasetBytes);
165+
var returnValue = cuvsRMMAlloc(cuvsResources, datasetMemorySegment, datasetBytes);
168166
checkCuVSError(returnValue, "cuvsRMMAlloc");
169167

170168
// IMPORTANT: this should only come AFTER cuvsRMMAlloc call
@@ -175,23 +173,20 @@ private IndexReference build() throws Throwable {
175173
long[] datasetShape = { rows, cols };
176174
MemorySegment datasetTensor = prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);
177175

178-
MemorySegment index = arena.allocate(cuvsBruteForceIndex_t);
179-
180-
returnValue = cuvsBruteForceIndexCreate(index);
181-
checkCuVSError(returnValue, "cuvsBruteForceIndexCreate");
176+
var indexReference = new IndexReference(datasetMemorySegmentP, datasetBytes, createBruteForceIndex());
182177

183178
returnValue = cuvsStreamSync(cuvsResources);
184179
checkCuVSError(returnValue, "cuvsStreamSync");
185180

186-
returnValue = cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, index);
181+
returnValue = cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, indexReference.indexPtr);
187182
checkCuVSError(returnValue, "cuvsBruteForceBuild");
188183

189184
returnValue = cuvsStreamSync(cuvsResources);
190185
checkCuVSError(returnValue, "cuvsStreamSync");
191186

192187
omp_set_num_threads(1);
193188

194-
return new IndexReference(index);
189+
return indexReference;
195190
}
196191
}
197192

@@ -236,17 +231,14 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
236231
var returnValue = cuvsStreamGet(cuvsResources, stream);
237232
checkCuVSError(returnValue, "cuvsStreamGet");
238233

239-
MemorySegment queriesD = arena.allocate(C_POINTER);
240-
MemorySegment neighborsD = arena.allocate(C_POINTER);
241-
MemorySegment distancesD = arena.allocate(C_POINTER);
242-
MemorySegment prefilterD = arena.allocate(C_POINTER);
243-
MemorySegment prefilterDP = MemorySegment.NULL;
244-
long prefilterLen = 0;
234+
MemorySegment queriesD = localArena.allocate(C_POINTER);
235+
MemorySegment neighborsD = localArena.allocate(C_POINTER);
236+
MemorySegment distancesD = localArena.allocate(C_POINTER);
245237

246238
long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
247239
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
248240
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
249-
long prefilterBytes = 0; // size assigned later
241+
long prefilterBytes = 0;
250242

251243
returnValue = cuvsRMMAlloc(cuvsResources, queriesD, queriesBytes);
252244
checkCuVSError(returnValue, "cuvsRMMAlloc");
@@ -259,14 +251,15 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
259251
MemorySegment queriesDP = queriesD.get(C_POINTER, 0);
260252
MemorySegment neighborsDP = neighborsD.get(C_POINTER, 0);
261253
MemorySegment distancesDP = distancesD.get(C_POINTER, 0);
254+
MemorySegment prefilterDP = MemorySegment.NULL;
262255

263256
cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);
264257

265-
long queriesShape[] = { numQueries, vectorDimension };
258+
long[] queriesShape = { numQueries, vectorDimension };
266259
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
267-
long neighborsShape[] = { numQueries, topk };
260+
long[] neighborsShape = { numQueries, topk };
268261
MemorySegment neighborsTensor = prepareTensor(arena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
269-
long distancesShape[] = { numQueries, topk };
262+
long[] distancesShape = { numQueries, topk };
270263
MemorySegment distancesTensor = prepareTensor(arena, distancesDP, distancesShape, 2, 32, 2, 2, 1);
271264

272265
MemorySegment prefilter = cuvsFilter.allocate(arena);
@@ -276,13 +269,14 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
276269
cuvsFilter.type(prefilter, 0); // NO_FILTER
277270
cuvsFilter.addr(prefilter, 0);
278271
} else {
279-
long prefilterShape[] = { (prefilterDataLength + 31) / 32 };
280-
prefilterLen = prefilterShape[0];
272+
long[] prefilterShape = { (prefilterDataLength + 31) / 32 };
273+
274+
MemorySegment prefilterD = localArena.allocate(C_POINTER);
275+
long prefilterLen = prefilterShape[0];
281276
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen;
282277

283278
returnValue = cuvsRMMAlloc(cuvsResources, prefilterD, prefilterBytes);
284279
checkCuVSError(returnValue, "cuvsRMMAlloc");
285-
286280
prefilterDP = prefilterD.get(C_POINTER, 0);
287281

288282
cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
@@ -296,7 +290,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
296290
returnValue = cuvsStreamSync(cuvsResources);
297291
checkCuVSError(returnValue, "cuvsStreamSync");
298292

299-
returnValue = cuvsBruteForceSearch(cuvsResources, bruteForceIndexReference.getMemorySegment(), queriesTensor,
293+
returnValue = cuvsBruteForceSearch(cuvsResources, bruteForceIndexReference.indexPtr, queriesTensor,
300294
neighborsTensor, distancesTensor, prefilter);
301295
checkCuVSError(returnValue, "cuvsBruteForceSearch");
302296

@@ -312,8 +306,10 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
312306
checkCuVSError(returnValue, "cuvsRMMFree");
313307
returnValue = cuvsRMMFree(cuvsResources, queriesDP, queriesBytes);
314308
checkCuVSError(returnValue, "cuvsRMMFree");
315-
returnValue = cuvsRMMFree(cuvsResources, prefilterDP, prefilterBytes);
316-
checkCuVSError(returnValue, "cuvsRMMFree");
309+
if (prefilterBytes > 0) {
310+
returnValue = cuvsRMMFree(cuvsResources, prefilterDP, prefilterBytes);
311+
checkCuVSError(returnValue, "cuvsRMMFree");
312+
}
317313

318314
return new BruteForceSearchResults(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment,
319315
distancesMemorySegment, cuvsQuery.getTopK(), cuvsQuery.getMapping(), numQueries);
@@ -332,17 +328,33 @@ public void serialize(OutputStream outputStream, Path tempFile) throws Throwable
332328
tempFile = tempFile.toAbsolutePath();
333329

334330
long cuvsRes = resources.getHandle();
335-
int returnValue = cuvsBruteForceSerialize(cuvsRes, resources.getArena().allocateFrom(tempFile.toString()),
336-
bruteForceIndexReference.getMemorySegment());
337-
checkCuVSError(returnValue, "cuvsBruteForceSerialize");
331+
try (var localArena = Arena.ofConfined()) {
332+
int returnValue = cuvsBruteForceSerialize(cuvsRes, localArena.allocateFrom(tempFile.toString()),
333+
bruteForceIndexReference.indexPtr);
334+
checkCuVSError(returnValue, "cuvsBruteForceSerialize");
335+
}
338336

339-
try (FileInputStream fileInputStream = new FileInputStream(tempFile.toFile())) {
340-
fileInputStream.transferTo(outputStream);
337+
try (var inputStream = Files.newInputStream(tempFile)) {
338+
inputStream.transferTo(outputStream);
341339
} finally {
342340
Files.deleteIfExists(tempFile);
343341
}
344342
}
345343

344+
private static MemorySegment createBruteForceIndex() {
345+
try (var localArena = Arena.ofConfined()) {
346+
MemorySegment indexPtrPtr = localArena.allocate(cuvsBruteForceIndex_t);
347+
// cuvsBruteForceIndexCreate gets a pointer to a cuvsBruteForceIndex_t, which is defined as a pointer to
348+
// cuvsBruteForceIndex.
349+
// It's basically a "out" parameter: the C functions will create the index and "return back" a pointer to it.
350+
// The "out parameter" pointer is needed only for the duration of the function invocation (it could be a stack
351+
// pointer, in C) so we allocate it from our localArena, unwrap it and return it.
352+
var returnValue = cuvsBruteForceIndexCreate(indexPtrPtr);
353+
checkCuVSError(returnValue, "cuvsBruteForceIndexCreate");
354+
return indexPtrPtr.get(cuvsBruteForceIndex_t, 0);
355+
}
356+
}
357+
346358
/**
347359
* Gets an instance of {@link IndexReference} by deserializing a BRUTEFORCE
348360
* index using an {@link InputStream}.
@@ -352,16 +364,16 @@ public void serialize(OutputStream outputStream, Path tempFile) throws Throwable
352364
*/
353365
private IndexReference deserialize(InputStream inputStream) throws Throwable {
354366
checkNotDestroyed();
355-
Path tmpIndexFile = Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".bf");
356-
tmpIndexFile = tmpIndexFile.toAbsolutePath();
357-
IndexReference indexReference = new IndexReference(resources);
367+
Path tmpIndexFile = Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".bf")
368+
.toAbsolutePath();
369+
IndexReference indexReference = new IndexReference(createBruteForceIndex());
358370

359-
try (var in = inputStream; FileOutputStream fileOutputStream = new FileOutputStream(tmpIndexFile.toFile())) {
360-
in.transferTo(fileOutputStream);
371+
try (inputStream; var outputStream = Files.newOutputStream(tmpIndexFile); var arena = Arena.ofConfined()) {
372+
inputStream.transferTo(outputStream);
361373

362374
long cuvsRes = resources.getHandle();
363-
int returnValue = cuvsBruteForceDeserialize(cuvsRes, resources.getArena().allocateFrom(tmpIndexFile.toString()),
364-
indexReference.getMemorySegment());
375+
int returnValue = cuvsBruteForceDeserialize(cuvsRes, arena.allocateFrom(tmpIndexFile.toString()),
376+
indexReference.indexPtr);
365377
checkCuVSError(returnValue, "cuvsBruteForceDeserialize");
366378

367379
} finally {
@@ -464,37 +476,24 @@ public BruteForceIndexImpl build() throws Throwable {
464476
}
465477

466478
/**
467-
* Holds the memory reference to a BRUTEFORCE index.
479+
* Holds the memory reference to a BRUTEFORCE index and its associated dataset
468480
*/
469-
protected static class IndexReference {
481+
private static class IndexReference {
470482

471-
private final MemorySegment memorySegment;
472-
473-
/**
474-
* Constructs CagraIndexReference and allocate the MemorySegment.
475-
*/
476-
protected IndexReference(CuVSResourcesImpl resources) {
477-
memorySegment = cuvsBruteForceIndex.allocate(resources.getArena());
478-
}
483+
private final MemorySegment datasetPtr;
484+
private final long datasetBytes;
485+
private final MemorySegment indexPtr;
479486

480-
/**
481-
* Constructs BruteForceIndexReference with an instance of MemorySegment passed
482-
* as a parameter.
483-
*
484-
* @param indexMemorySegment the MemorySegment instance to use for containing
485-
* index reference
486-
*/
487-
protected IndexReference(MemorySegment indexMemorySegment) {
488-
this.memorySegment = indexMemorySegment;
487+
private IndexReference(MemorySegment datasetPtr, long datasetBytes, MemorySegment indexPtr) {
488+
this.datasetPtr = datasetPtr;
489+
this.datasetBytes = datasetBytes;
490+
this.indexPtr = indexPtr;
489491
}
490492

491-
/**
492-
* Gets the instance of index MemorySegment.
493-
*
494-
* @return index MemorySegment
495-
*/
496-
protected MemorySegment getMemorySegment() {
497-
return memorySegment;
493+
private IndexReference(MemorySegment indexPtr) {
494+
this.datasetPtr = MemorySegment.NULL;
495+
this.datasetBytes = 0;
496+
this.indexPtr = indexPtr;
498497
}
499498
}
500499
}

0 commit comments

Comments
 (0)