diff --git a/thirdparty/faiss/faiss/gpu/impl/Distance.cu b/thirdparty/faiss/faiss/gpu/impl/Distance.cu index 688b619db..c408cad10 100644 --- a/thirdparty/faiss/faiss/gpu/impl/Distance.cu +++ b/thirdparty/faiss/faiss/gpu/impl/Distance.cu @@ -367,10 +367,13 @@ void runDistance( k, streams[curStream]); } else { + // bitset need to match real idx + auto batchBitsetView = bitset.narrow( + 0, int(j / 8), bitset.getSize(0) - int(j / 8)); // Write into the intermediate output runBlockSelect( distanceBufView, - bitset, + batchBitsetView, outDistanceBufColView, outIndexBufColView, true, diff --git a/thirdparty/faiss/faiss/gpu/utils/BlockSelectKernel.cuh b/thirdparty/faiss/faiss/gpu/utils/BlockSelectKernel.cuh index d0e84f242..f185d8c84 100644 --- a/thirdparty/faiss/faiss/gpu/utils/BlockSelectKernel.cuh +++ b/thirdparty/faiss/faiss/gpu/utils/BlockSelectKernel.cuh @@ -237,7 +237,9 @@ __global__ void blockSelectPair( bool bitsetEmpty = (bitset.getSize(0) == 0); for (; i < limit; i += ThreadsPerBlock) { - if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + // bitset need to match real idx + if (bitsetEmpty || + (!(bitset[inV[row][i] >> 3] & (0x1 << (inV[row][i] & 0x7))))) { heap.addThreadQ(*inKStart, *inVStart); } heap.checkThreadQ(); @@ -248,7 +250,9 @@ __global__ void blockSelectPair( // Handle last remainder fraction of a warp of elements if (i < inK.getSize(1)) { - if (bitsetEmpty || (!(bitset[i >> 3] & (0x1 << (i & 0x7))))) { + // bitset need to match real idx + if (bitsetEmpty || + (!(bitset[inV[row][i] >> 3] & (0x1 << (inV[row][i] & 0x7))))) { heap.addThreadQ(*inKStart, *inVStart); } }