28
28
import static com .nvidia .cuvs .internal .common .Util .cudaMemcpy ;
29
29
import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .*;
30
30
import static com .nvidia .cuvs .internal .common .Util .prepareTensor ;
31
+ import static com .nvidia .cuvs .internal .panama .headers_h .cudaStream_t ;
31
32
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceBuild ;
32
33
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceDeserialize ;
33
34
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceIndexCreate ;
37
38
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceSerialize ;
38
39
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsRMMAlloc ;
39
40
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsRMMFree ;
40
- import static com .nvidia .cuvs .internal .panama .headers_h .cuvsResources_t ;
41
41
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamGet ;
42
42
import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamSync ;
43
43
import static com .nvidia .cuvs .internal .panama .headers_h .omp_set_num_threads ;
44
- import static com .nvidia .cuvs .internal .panama .headers_h .cudaStream_t ;
45
44
46
- import java .io .FileInputStream ;
47
- import java .io .FileOutputStream ;
48
45
import java .io .InputStream ;
49
46
import java .io .OutputStream ;
50
47
import java .lang .foreign .Arena ;
64
61
import com .nvidia .cuvs .Dataset ;
65
62
import com .nvidia .cuvs .SearchResults ;
66
63
import com .nvidia .cuvs .internal .common .Util ;
67
- import com .nvidia .cuvs .internal .panama .cuvsBruteForceIndex ;
68
64
import com .nvidia .cuvs .internal .panama .cuvsFilter ;
69
65
70
66
/**
@@ -130,8 +126,13 @@ private void checkNotDestroyed() {
130
126
public void destroyIndex () throws Throwable {
131
127
checkNotDestroyed ();
132
128
try {
133
- int returnValue = cuvsBruteForceIndexDestroy (bruteForceIndexReference .getMemorySegment () );
129
+ int returnValue = cuvsBruteForceIndexDestroy (bruteForceIndexReference .indexPtr );
134
130
checkCuVSError (returnValue , "cuvsBruteForceIndexDestroy" );
131
+
132
+ if (bruteForceIndexReference .datasetBytes > 0 ) {
133
+ returnValue = cuvsRMMFree (resources .getHandle (), bruteForceIndexReference .datasetPtr , bruteForceIndexReference .datasetBytes );
134
+ checkCuVSError (returnValue , "cuvsRMMFree" );
135
+ }
135
136
} finally {
136
137
destroyed = true ;
137
138
}
@@ -145,7 +146,7 @@ public void destroyIndex() throws Throwable {
145
146
* @return an instance of {@link IndexReference} that holds the pointer to the
146
147
* index
147
148
*/
148
- private IndexReference build () throws Throwable {
149
+ private IndexReference build () {
149
150
try (var localArena = Arena .ofConfined ()) {
150
151
long rows = dataset != null ? dataset .size (): vectors .length ;
151
152
long cols = dataset != null ? dataset .dimensions (): (rows > 0 ? vectors [0 ].length : 0 );
@@ -155,16 +156,13 @@ private IndexReference build() throws Throwable {
155
156
Util .buildMemorySegment (resources .getArena (), vectors );
156
157
157
158
long cuvsResources = resources .getHandle ();
158
- MemorySegment stream = arena .allocate (cudaStream_t );
159
- var returnValue = cuvsStreamGet (cuvsResources , stream );
160
- checkCuVSError (returnValue , "cuvsStreamGet" );
161
159
162
160
omp_set_num_threads (bruteForceIndexParams .getNumWriterThreads ());
163
161
164
- MemorySegment datasetMemorySegment = arena .allocate (C_POINTER );
162
+ MemorySegment datasetMemorySegment = localArena .allocate (C_POINTER );
165
163
166
164
long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols ;
167
- returnValue = cuvsRMMAlloc (cuvsResources , datasetMemorySegment , datasetBytes );
165
+ var returnValue = cuvsRMMAlloc (cuvsResources , datasetMemorySegment , datasetBytes );
168
166
checkCuVSError (returnValue , "cuvsRMMAlloc" );
169
167
170
168
// IMPORTANT: this should only come AFTER cuvsRMMAlloc call
@@ -175,23 +173,20 @@ private IndexReference build() throws Throwable {
175
173
long [] datasetShape = { rows , cols };
176
174
MemorySegment datasetTensor = prepareTensor (arena , datasetMemorySegmentP , datasetShape , 2 , 32 , 2 , 2 , 1 );
177
175
178
- MemorySegment index = arena .allocate (cuvsBruteForceIndex_t );
179
-
180
- returnValue = cuvsBruteForceIndexCreate (index );
181
- checkCuVSError (returnValue , "cuvsBruteForceIndexCreate" );
176
+ var indexReference = new IndexReference (datasetMemorySegmentP , datasetBytes , createBruteForceIndex ());
182
177
183
178
returnValue = cuvsStreamSync (cuvsResources );
184
179
checkCuVSError (returnValue , "cuvsStreamSync" );
185
180
186
- returnValue = cuvsBruteForceBuild (cuvsResources , datasetTensor , 0 , 0.0f , index );
181
+ returnValue = cuvsBruteForceBuild (cuvsResources , datasetTensor , 0 , 0.0f , indexReference . indexPtr );
187
182
checkCuVSError (returnValue , "cuvsBruteForceBuild" );
188
183
189
184
returnValue = cuvsStreamSync (cuvsResources );
190
185
checkCuVSError (returnValue , "cuvsStreamSync" );
191
186
192
187
omp_set_num_threads (1 );
193
188
194
- return new IndexReference ( index ) ;
189
+ return indexReference ;
195
190
}
196
191
}
197
192
@@ -236,17 +231,14 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
236
231
var returnValue = cuvsStreamGet (cuvsResources , stream );
237
232
checkCuVSError (returnValue , "cuvsStreamGet" );
238
233
239
- MemorySegment queriesD = arena .allocate (C_POINTER );
240
- MemorySegment neighborsD = arena .allocate (C_POINTER );
241
- MemorySegment distancesD = arena .allocate (C_POINTER );
242
- MemorySegment prefilterD = arena .allocate (C_POINTER );
243
- MemorySegment prefilterDP = MemorySegment .NULL ;
244
- long prefilterLen = 0 ;
234
+ MemorySegment queriesD = localArena .allocate (C_POINTER );
235
+ MemorySegment neighborsD = localArena .allocate (C_POINTER );
236
+ MemorySegment distancesD = localArena .allocate (C_POINTER );
245
237
246
238
long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension ;
247
239
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk ;
248
240
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk ;
249
- long prefilterBytes = 0 ; // size assigned later
241
+ long prefilterBytes = 0 ;
250
242
251
243
returnValue = cuvsRMMAlloc (cuvsResources , queriesD , queriesBytes );
252
244
checkCuVSError (returnValue , "cuvsRMMAlloc" );
@@ -259,14 +251,15 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
259
251
MemorySegment queriesDP = queriesD .get (C_POINTER , 0 );
260
252
MemorySegment neighborsDP = neighborsD .get (C_POINTER , 0 );
261
253
MemorySegment distancesDP = distancesD .get (C_POINTER , 0 );
254
+ MemorySegment prefilterDP = MemorySegment .NULL ;
262
255
263
256
cudaMemcpy (queriesDP , querySeg , queriesBytes , INFER_DIRECTION );
264
257
265
- long queriesShape [] = { numQueries , vectorDimension };
258
+ long [] queriesShape = { numQueries , vectorDimension };
266
259
MemorySegment queriesTensor = prepareTensor (arena , queriesDP , queriesShape , 2 , 32 , 2 , 2 , 1 );
267
- long neighborsShape [] = { numQueries , topk };
260
+ long [] neighborsShape = { numQueries , topk };
268
261
MemorySegment neighborsTensor = prepareTensor (arena , neighborsDP , neighborsShape , 0 , 64 , 2 , 2 , 1 );
269
- long distancesShape [] = { numQueries , topk };
262
+ long [] distancesShape = { numQueries , topk };
270
263
MemorySegment distancesTensor = prepareTensor (arena , distancesDP , distancesShape , 2 , 32 , 2 , 2 , 1 );
271
264
272
265
MemorySegment prefilter = cuvsFilter .allocate (arena );
@@ -276,13 +269,14 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
276
269
cuvsFilter .type (prefilter , 0 ); // NO_FILTER
277
270
cuvsFilter .addr (prefilter , 0 );
278
271
} else {
279
- long prefilterShape [] = { (prefilterDataLength + 31 ) / 32 };
280
- prefilterLen = prefilterShape [0 ];
272
+ long [] prefilterShape = { (prefilterDataLength + 31 ) / 32 };
273
+
274
+ MemorySegment prefilterD = localArena .allocate (C_POINTER );
275
+ long prefilterLen = prefilterShape [0 ];
281
276
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen ;
282
277
283
278
returnValue = cuvsRMMAlloc (cuvsResources , prefilterD , prefilterBytes );
284
279
checkCuVSError (returnValue , "cuvsRMMAlloc" );
285
-
286
280
prefilterDP = prefilterD .get (C_POINTER , 0 );
287
281
288
282
cudaMemcpy (prefilterDP , prefilterDataMemorySegment , prefilterBytes , HOST_TO_DEVICE );
@@ -296,7 +290,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
296
290
returnValue = cuvsStreamSync (cuvsResources );
297
291
checkCuVSError (returnValue , "cuvsStreamSync" );
298
292
299
- returnValue = cuvsBruteForceSearch (cuvsResources , bruteForceIndexReference .getMemorySegment () , queriesTensor ,
293
+ returnValue = cuvsBruteForceSearch (cuvsResources , bruteForceIndexReference .indexPtr , queriesTensor ,
300
294
neighborsTensor , distancesTensor , prefilter );
301
295
checkCuVSError (returnValue , "cuvsBruteForceSearch" );
302
296
@@ -312,8 +306,10 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
312
306
checkCuVSError (returnValue , "cuvsRMMFree" );
313
307
returnValue = cuvsRMMFree (cuvsResources , queriesDP , queriesBytes );
314
308
checkCuVSError (returnValue , "cuvsRMMFree" );
315
- returnValue = cuvsRMMFree (cuvsResources , prefilterDP , prefilterBytes );
316
- checkCuVSError (returnValue , "cuvsRMMFree" );
309
+ if (prefilterBytes > 0 ) {
310
+ returnValue = cuvsRMMFree (cuvsResources , prefilterDP , prefilterBytes );
311
+ checkCuVSError (returnValue , "cuvsRMMFree" );
312
+ }
317
313
318
314
return new BruteForceSearchResults (neighborsSequenceLayout , distancesSequenceLayout , neighborsMemorySegment ,
319
315
distancesMemorySegment , cuvsQuery .getTopK (), cuvsQuery .getMapping (), numQueries );
@@ -332,17 +328,33 @@ public void serialize(OutputStream outputStream, Path tempFile) throws Throwable
332
328
tempFile = tempFile .toAbsolutePath ();
333
329
334
330
long cuvsRes = resources .getHandle ();
335
- int returnValue = cuvsBruteForceSerialize (cuvsRes , resources .getArena ().allocateFrom (tempFile .toString ()),
336
- bruteForceIndexReference .getMemorySegment ());
337
- checkCuVSError (returnValue , "cuvsBruteForceSerialize" );
331
+ try (var localArena = Arena .ofConfined ()) {
332
+ int returnValue = cuvsBruteForceSerialize (cuvsRes , localArena .allocateFrom (tempFile .toString ()),
333
+ bruteForceIndexReference .indexPtr );
334
+ checkCuVSError (returnValue , "cuvsBruteForceSerialize" );
335
+ }
338
336
339
- try (FileInputStream fileInputStream = new FileInputStream ( tempFile . toFile () )) {
340
- fileInputStream .transferTo (outputStream );
337
+ try (var inputStream = Files . newInputStream ( tempFile )) {
338
+ inputStream .transferTo (outputStream );
341
339
} finally {
342
340
Files .deleteIfExists (tempFile );
343
341
}
344
342
}
345
343
344
+ private static MemorySegment createBruteForceIndex () {
345
+ try (var localArena = Arena .ofConfined ()) {
346
+ MemorySegment indexPtrPtr = localArena .allocate (cuvsBruteForceIndex_t );
347
+ // cuvsBruteForceIndexCreate gets a pointer to a cuvsBruteForceIndex_t, which is defined as a pointer to
348
+ // cuvsBruteForceIndex.
349
+ // It's basically a "out" parameter: the C functions will create the index and "return back" a pointer to it.
350
+ // The "out parameter" pointer is needed only for the duration of the function invocation (it could be a stack
351
+ // pointer, in C) so we allocate it from our localArena, unwrap it and return it.
352
+ var returnValue = cuvsBruteForceIndexCreate (indexPtrPtr );
353
+ checkCuVSError (returnValue , "cuvsBruteForceIndexCreate" );
354
+ return indexPtrPtr .get (cuvsBruteForceIndex_t , 0 );
355
+ }
356
+ }
357
+
346
358
/**
347
359
* Gets an instance of {@link IndexReference} by deserializing a BRUTEFORCE
348
360
* index using an {@link InputStream}.
@@ -352,16 +364,16 @@ public void serialize(OutputStream outputStream, Path tempFile) throws Throwable
352
364
*/
353
365
private IndexReference deserialize (InputStream inputStream ) throws Throwable {
354
366
checkNotDestroyed ();
355
- Path tmpIndexFile = Files .createTempFile (resources .tempDirectory (), UUID .randomUUID ().toString (), ".bf" );
356
- tmpIndexFile = tmpIndexFile .toAbsolutePath ();
357
- IndexReference indexReference = new IndexReference (resources );
367
+ Path tmpIndexFile = Files .createTempFile (resources .tempDirectory (), UUID .randomUUID ().toString (), ".bf" )
368
+ .toAbsolutePath ();
369
+ IndexReference indexReference = new IndexReference (createBruteForceIndex () );
358
370
359
- try (var in = inputStream ; FileOutputStream fileOutputStream = new FileOutputStream ( tmpIndexFile . toFile () )) {
360
- in .transferTo (fileOutputStream );
371
+ try (inputStream ; var outputStream = Files . newOutputStream ( tmpIndexFile ); var arena = Arena . ofConfined ( )) {
372
+ inputStream .transferTo (outputStream );
361
373
362
374
long cuvsRes = resources .getHandle ();
363
- int returnValue = cuvsBruteForceDeserialize (cuvsRes , resources . getArena () .allocateFrom (tmpIndexFile .toString ()),
364
- indexReference .getMemorySegment () );
375
+ int returnValue = cuvsBruteForceDeserialize (cuvsRes , arena .allocateFrom (tmpIndexFile .toString ()),
376
+ indexReference .indexPtr );
365
377
checkCuVSError (returnValue , "cuvsBruteForceDeserialize" );
366
378
367
379
} finally {
@@ -464,37 +476,24 @@ public BruteForceIndexImpl build() throws Throwable {
464
476
}
465
477
466
478
/**
467
- * Holds the memory reference to a BRUTEFORCE index.
479
+ * Holds the memory reference to a BRUTEFORCE index and its associated dataset
468
480
*/
469
- protected static class IndexReference {
481
+ private static class IndexReference {
470
482
471
- private final MemorySegment memorySegment ;
472
-
473
- /**
474
- * Constructs CagraIndexReference and allocate the MemorySegment.
475
- */
476
- protected IndexReference (CuVSResourcesImpl resources ) {
477
- memorySegment = cuvsBruteForceIndex .allocate (resources .getArena ());
478
- }
483
+ private final MemorySegment datasetPtr ;
484
+ private final long datasetBytes ;
485
+ private final MemorySegment indexPtr ;
479
486
480
- /**
481
- * Constructs BruteForceIndexReference with an instance of MemorySegment passed
482
- * as a parameter.
483
- *
484
- * @param indexMemorySegment the MemorySegment instance to use for containing
485
- * index reference
486
- */
487
- protected IndexReference (MemorySegment indexMemorySegment ) {
488
- this .memorySegment = indexMemorySegment ;
487
+ private IndexReference (MemorySegment datasetPtr , long datasetBytes , MemorySegment indexPtr ) {
488
+ this .datasetPtr = datasetPtr ;
489
+ this .datasetBytes = datasetBytes ;
490
+ this .indexPtr = indexPtr ;
489
491
}
490
492
491
- /**
492
- * Gets the instance of index MemorySegment.
493
- *
494
- * @return index MemorySegment
495
- */
496
- protected MemorySegment getMemorySegment () {
497
- return memorySegment ;
493
+ private IndexReference (MemorySegment indexPtr ) {
494
+ this .datasetPtr = MemorySegment .NULL ;
495
+ this .datasetBytes = 0 ;
496
+ this .indexPtr = indexPtr ;
498
497
}
499
498
}
500
499
}
0 commit comments