diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java index 0be0e088..e417d50e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java @@ -131,18 +131,10 @@ public float similarityTo(int node2) { } protected float decodedCosine(int node2) { - float sum = 0.0f; - float aMag = 0.0f; ByteSequence encoded = cv.get(node2); - for (int m = 0; m < encoded.length(); ++m) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); - sum += partialSums.get((m * cv.pq.getClusterCount()) + centroidIndex); - aMag += aMagnitude.get((m * cv.pq.getClusterCount()) + centroidIndex); - } - - return (float) (sum / Math.sqrt(aMag * bMagnitude)); + return VectorUtil.decodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index a6d87807..d860cc1b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -194,4 +194,8 @@ public static float max(VectorFloat v) { public static float min(VectorFloat v) { return impl.min(v); } + + public static float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return impl.decodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 46cb4f18..e3e2953e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -199,4 +199,18 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float max(VectorFloat v); float min(VectorFloat v); + default float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + float sum = 0.0f; + float aMag = 0.0f; + + for (int m = 0; m < encoded.length(); ++m) { + int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + var index = m * clusterCount + centroidIndex; + sum += partialSums.get(index); + aMag += aMagnitude.get(index); + } + + return (float) (sum / Math.sqrt(aMag * bMagnitude)); + } } diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 886186fa..c807f33c 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -319,6 +319,49 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { + __m512 sum = _mm512_setzero_ps(); + __m512 vaMagnitude = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + __m512i indexRegister = initialIndexRegister; + __m512i scale = _mm512_set1_epi32(clusterCount); + + + for (; i < limit; i += 16) { + // Load and convert baseOffsets to integers + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // Scale the baseOffsets by the cluster count + __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); + + // Calculate the final convOffsets by adding the scaled indexes and the base offsets + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); + + // Gather and sum values for partial sums and a magnitude + __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); + sum = _mm512_add_ps(sum, partialSumVals); + + __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); + vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + } + + // Reduce sums + float sumResult = _mm512_reduce_add_ps(sum); + float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); + + // Handle the remaining elements + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + baseOffsets[i]; + sumResult += partialSums[offset]; + aMagnitudeResult += aMagnitude[offset]; + } + + return sumResult / sqrtf(aMagnitudeResult * bMagnitude); +} + void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index a5410ef5..1b96a0a8 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength); +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 3fe12e71..624ec27f 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -155,4 +155,10 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int c NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance, ((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get()); } + + @Override + public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return NativeSimdOps.decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index e148b1be..47d616ed 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -452,6 +452,58 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M } } + private static class decoded_cosine_similarity_f32_512 { + public static final FunctionDescriptor DESC = FunctionDescriptor.of( + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT + ); + + public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle( + NativeSimdOps.findOrThrow("decoded_cosine_similarity_f32_512"), + DESC, Linker.Option.critical(true)); + } + + /** + * Function descriptor for: + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static FunctionDescriptor decoded_cosine_similarity_f32_512$descriptor() { + return decoded_cosine_similarity_f32_512.DESC; + } + + /** + * Downcall method handle for: + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static MethodHandle decoded_cosine_similarity_f32_512$handle() { + return decoded_cosine_similarity_f32_512.HANDLE; + } + /** + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static float decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) { + var mh$ = decoded_cosine_similarity_f32_512.HANDLE; + try { + if (TRACE_DOWNCALLS) { + traceDowncall("decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } + return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } + private static class calculate_partial_sums_dot_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( NativeSimdOps.C_POINTER, diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index e0c2be5e..060ca9e1 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -92,11 +92,7 @@ public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int b @Override public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets) { - float sum = 0f; - for (int i = 0; i < baseOffsets.length(); i++) { - sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))); - } - return sum; + return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ArrayByteSequence) baseOffsets).get()); } @Override @@ -159,5 +155,11 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int public void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials) { SimdOps.quantizePartials(delta, (ArrayVectorFloat) partials, (ArrayVectorFloat) partialBases, (ArrayByteSequence) quantizedPartials); } + + @Override + public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return SimdOps.decodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index d5d132e8..a6a73f3d 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -525,10 +525,9 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512); int i = 0; int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(1).add(i).mul(dataBase); - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) @@ -536,7 +535,8 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { .add(scale) .intoArray(convOffsets,0); - sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, 0, convOffsets, 0)); + var offset = i * dataBase; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, offset, convOffsets, 0)); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -553,9 +553,9 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); int i = 0; int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_64.length()) { - var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(1).add(i).mul(dataBase); ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) @@ -564,7 +564,8 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { .add(scale) .intoArray(convOffsets,0); - sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, 0, convOffsets, 0)); + var offset = i * dataBase; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, offset, convOffsets, 0)); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -658,4 +659,88 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } } + + public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + return HAS_AVX512 + ? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) + : decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + } + + public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + var sum = FloatVector.zero(FloatVector.SPECIES_512); + var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); + var baseOffsets = encoded.get(); + var partialSumsArray = partialSums.get(); + var aMagnitudeArray = aMagnitude.get(); + + int[] convOffsets = scratchInt512.get(); + int i = 0; + int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount); + + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) + .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) + .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) + .reinterpretAsInts() + .add(scale) + .intoArray(convOffsets,0); + + var offset = i * clusterCount; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, partialSumsArray, offset, convOffsets, 0)); + vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_512, aMagnitudeArray, offset, convOffsets, 0)); + } + + float sumResult = sum.reduceLanes(VectorOperators.ADD); + float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); + + for (; i < baseOffsets.length; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + sumResult += partialSumsArray[offset]; + aMagnitudeResult += aMagnitudeArray[offset]; + } + + return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); + } + + public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + var sum = FloatVector.zero(FloatVector.SPECIES_256); + var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + var baseOffsets = encoded.get(); + var partialSumsArray = partialSums.get(); + var aMagnitudeArray = aMagnitude.get(); + + int[] convOffsets = scratchInt256.get(); + int i = 0; + int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount); + + for (; i < limit; i += ByteVector.SPECIES_64.length()) { + + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) + .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) + .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) + .reinterpretAsInts() + .add(scale) + .intoArray(convOffsets,0); + + var offset = i * clusterCount; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, partialSumsArray, offset, convOffsets, 0)); + vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_256, aMagnitudeArray, offset, convOffsets, 0)); + } + + float sumResult = sum.reduceLanes(VectorOperators.ADD); + float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); + + for (; i < baseOffsets.length; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + sumResult += partialSumsArray[offset]; + aMagnitudeResult += aMagnitudeArray[offset]; + } + + return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); + } }