From f6eca79d5d99fa9d9193a1dcb2001e946f7adbe7 Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Tue, 12 Nov 2024 16:36:10 -0600 Subject: [PATCH 1/6] WIP --- .../github/jbellis/jvector/pq/PQDecoder.java | 10 +---- .../jbellis/jvector/vector/VectorUtil.java | 4 ++ .../jvector/vector/VectorUtilSupport.java | 14 +++++++ .../vector/PanamaVectorUtilSupport.java | 6 +++ .../jbellis/jvector/vector/SimdOps.java | 39 +++++++++++++++++++ 5 files changed, 64 insertions(+), 9 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java index 0be0e088..e417d50e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java @@ -131,18 +131,10 @@ public float similarityTo(int node2) { } protected float decodedCosine(int node2) { - float sum = 0.0f; - float aMag = 0.0f; ByteSequence encoded = cv.get(node2); - for (int m = 0; m < encoded.length(); ++m) { - int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); - sum += partialSums.get((m * cv.pq.getClusterCount()) + centroidIndex); - aMag += aMagnitude.get((m * cv.pq.getClusterCount()) + centroidIndex); - } - - return (float) (sum / Math.sqrt(aMag * bMagnitude)); + return VectorUtil.decodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude); } } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index a6d87807..d860cc1b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -194,4 +194,8 @@ public static float max(VectorFloat v) { public static float min(VectorFloat v) { return impl.min(v); } + + public static float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { + return impl.decodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 46cb4f18..e3e2953e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -199,4 +199,18 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int float max(VectorFloat v); float min(VectorFloat v); + default float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + float sum = 0.0f; + float aMag = 0.0f; + + for (int m = 0; m < encoded.length(); ++m) { + int centroidIndex = Byte.toUnsignedInt(encoded.get(m)); + var index = m * clusterCount + centroidIndex; + sum += partialSums.get(index); + aMag += aMagnitude.get(index); + } + + return (float) (sum / Math.sqrt(aMag * bMagnitude)); + } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index e0c2be5e..e71d780a 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -159,5 +159,11 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int public void quantizePartials(float delta, VectorFloat partials, VectorFloat partialBases, ByteSequence quantizedPartials) { SimdOps.quantizePartials(delta, (ArrayVectorFloat) partials, (ArrayVectorFloat) partialBases, (ArrayByteSequence) quantizedPartials); } + + @Override + public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return SimdOps.decodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index d5d132e8..b16d82b8 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -658,4 +658,43 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } } + + public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + var sum = FloatVector.zero(FloatVector.SPECIES_512); + var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); + var baseOffsets = encoded.get(); + var partialSumsArray = partialSums.get(); + var aMagnitudeArray = aMagnitude.get(); + + int[] convOffsets = scratchInt512.get(); + int i = 0; + int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount); + + for (; i < limit; i += ByteVector.SPECIES_128.length()) { + + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) + .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) + .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) + .reinterpretAsInts() + .add(scale) + .intoArray(convOffsets,0); + + var offset = i * clusterCount; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, partialSumsArray, offset, convOffsets, 0)); + vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_512, aMagnitudeArray, offset, convOffsets, 0)); + } + + float sumResult = sum.reduceLanes(VectorOperators.ADD); + float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); + + for (; i < baseOffsets.length; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + sumResult += partialSumsArray[offset]; + aMagnitudeResult += aMagnitudeArray[offset]; + } + + return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); + } } From 20ed4ac99bb0ea242618cd5399912338c496a9e0 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 13:39:43 -0600 Subject: [PATCH 2/6] Break decodedCosineSimilarity out by HAS_AVX512 --- .../jbellis/jvector/vector/SimdOps.java | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index b16d82b8..55ef28f3 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -660,6 +660,12 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + return HAS_AVX512 + ? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) + : decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); + } + + public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_512); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); var baseOffsets = encoded.get(); @@ -697,4 +703,43 @@ public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clust return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } + + public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + var sum = FloatVector.zero(FloatVector.SPECIES_256); + var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + var baseOffsets = encoded.get(); + var partialSumsArray = partialSums.get(); + var aMagnitudeArray = aMagnitude.get(); + + int[] convOffsets = scratchInt256.get(); + int i = 0; + int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount); + + for (; i < limit; i += ByteVector.SPECIES_64.length()) { + + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) + .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) + .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) + .reinterpretAsInts() + .add(scale) + .intoArray(convOffsets,0); + + var offset = i * clusterCount; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, partialSumsArray, offset, convOffsets, 0)); + vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_256, aMagnitudeArray, offset, convOffsets, 0)); + } + + float sumResult = sum.reduceLanes(VectorOperators.ADD); + float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); + + for (; i < baseOffsets.length; i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + sumResult += partialSumsArray[offset]; + aMagnitudeResult += aMagnitudeArray[offset]; + } + + return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); + } } From bebb4f6f60d7cac9a6073922652f7a6e96df41d6 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 15:08:26 -0600 Subject: [PATCH 3/6] Make PVUS#assembleAndSum use SimdOps; optimize SimdOps assembleAndSum --- .../jbellis/jvector/vector/PanamaVectorUtilSupport.java | 6 +----- .../java/io/github/jbellis/jvector/vector/SimdOps.java | 9 ++++----- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index e71d780a..060ca9e1 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -92,11 +92,7 @@ public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int b @Override public float assembleAndSum(VectorFloat data, int dataBase, ByteSequence baseOffsets) { - float sum = 0f; - for (int i = 0; i < baseOffsets.length(); i++) { - sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))); - } - return sum; + return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ArrayByteSequence) baseOffsets).get()); } @Override diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 55ef28f3..60c5170b 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -525,11 +525,10 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512); int i = 0; int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(1).add(i).mul(dataBase); - - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i * dataBase) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() @@ -553,11 +552,11 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); int i = 0; int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_64.length()) { - var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(1).add(i).mul(dataBase); - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i * dataBase) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts() From feb1e02b3c068e5e294513903dfb9e3167cc24eb Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Thu, 14 Nov 2024 16:18:47 -0600 Subject: [PATCH 4/6] Attempt native implemenation --- jvector-native/src/main/c/jvector_simd.c | 45 ++++++++++++++++++++++++ jvector-native/src/main/c/jvector_simd.h | 1 + 2 files changed, 46 insertions(+) diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 886186fa..22422502 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -318,6 +318,51 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { + __m512 sum = _mm512_setzero_ps(); + __m512 vaMagnitude = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + __m512i indexRegister = initialIndexRegister; + __m512i scale = _mm512_set1_epi32(clusterCount); + + + for (; i < limit; i += 16) { + // Load and convert baseOffsets to integers + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // Scale the baseOffsets by the cluster count + __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); + + // Compute the offset base by multiplying 'i' with clusterCount and broadcasting to all lanes + __m512i offsetBase = _mm512_set1_epi32(i * clusterCount); + + // Calculate the final convOffsets by adding the scaled offsets and the offset base + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, offsetBase); + + // Gather and sum values for partial sums and a magnitude + __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); + sum = _mm512_add_ps(sum, partialSumVals); + + __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); + vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + } + + // Reduce sums + float sumResult = _mm512_reduce_add_ps(sum); + float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); + + // Handle the remaining elements + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + baseOffsets[i]; + sumResult += partialSums[offset]; + aMagnitudeResult += aMagnitude[offset]; + } + + return sumResult / sqrtf(aMagnitudeResult * bMagnitude); +} void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index a5410ef5..1b96a0a8 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength); +float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); From aee061ea656b795d1f8a6d47f9dd3ac4a0129c59 Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Thu, 14 Nov 2024 16:38:19 -0600 Subject: [PATCH 5/6] Fix assembleAndSum --- .../java/io/github/jbellis/jvector/vector/SimdOps.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 60c5170b..a6a73f3d 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -528,14 +528,15 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) { var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase); for (; i < limit; i += ByteVector.SPECIES_128.length()) { - ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i * dataBase) + ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i) .convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512) .reinterpretAsInts() .add(scale) .intoArray(convOffsets,0); - sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, 0, convOffsets, 0)); + var offset = i * dataBase; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, offset, convOffsets, 0)); } float res = sum.reduceLanes(VectorOperators.ADD); @@ -556,14 +557,15 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) { for (; i < limit; i += ByteVector.SPECIES_64.length()) { - ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i * dataBase) + ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i) .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0) .lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256) .reinterpretAsInts() .add(scale) .intoArray(convOffsets,0); - sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, 0, convOffsets, 0)); + var offset = i * dataBase; + sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, offset, convOffsets, 0)); } float res = sum.reduceLanes(VectorOperators.ADD); From 3d79217e01d462752bff8f67bf6843c6308a6cd2 Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Mon, 18 Nov 2024 17:47:18 -0600 Subject: [PATCH 6/6] Fix decoded_cosine_similarity_f32_512. Generate bindings with jextract. Call binding from NativeVectorUtilSupport --- jvector-native/src/main/c/jvector_simd.c | 8 ++- .../vector/NativeVectorUtilSupport.java | 6 +++ .../jvector/vector/cnative/NativeSimdOps.java | 52 +++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 22422502..c807f33c 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -318,6 +318,7 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c return res; } + float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { __m512 sum = _mm512_setzero_ps(); __m512 vaMagnitude = _mm512_setzero_ps(); @@ -336,11 +337,8 @@ float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int ba // Scale the baseOffsets by the cluster count __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); - // Compute the offset base by multiplying 'i' with clusterCount and broadcasting to all lanes - __m512i offsetBase = _mm512_set1_epi32(i * clusterCount); - - // Calculate the final convOffsets by adding the scaled offsets and the offset base - __m512i convOffsets = _mm512_add_epi32(scaledOffsets, offsetBase); + // Calculate the final convOffsets by adding the scaled indexes and the base offsets + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); // Gather and sum values for partial sums and a magnitude __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 3fe12e71..624ec27f 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -155,4 +155,10 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence shuffles, int c NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance, ((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get()); } + + @Override + public float decodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) + { + return NativeSimdOps.decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index e148b1be..47d616ed 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -452,6 +452,58 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M } } + private static class decoded_cosine_similarity_f32_512 { + public static final FunctionDescriptor DESC = FunctionDescriptor.of( + NativeSimdOps.C_FLOAT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_FLOAT + ); + + public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle( + NativeSimdOps.findOrThrow("decoded_cosine_similarity_f32_512"), + DESC, Linker.Option.critical(true)); + } + + /** + * Function descriptor for: + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static FunctionDescriptor decoded_cosine_similarity_f32_512$descriptor() { + return decoded_cosine_similarity_f32_512.DESC; + } + + /** + * Downcall method handle for: + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static MethodHandle decoded_cosine_similarity_f32_512$handle() { + return decoded_cosine_similarity_f32_512.HANDLE; + } + /** + * {@snippet lang=c : + * float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude) + * } + */ + public static float decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) { + var mh$ = decoded_cosine_similarity_f32_512.HANDLE; + try { + if (TRACE_DOWNCALLS) { + traceDowncall("decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } + return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } + private static class calculate_partial_sums_dot_f32_512 { public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( NativeSimdOps.C_POINTER,