Skip to content

Commit

Permalink
Attempt native implemenation
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljmarshall committed Nov 14, 2024
1 parent bebb4f6 commit b43d85c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
35 changes: 35 additions & 0 deletions jvector-native/src/main/c/jvector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,41 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c
return res;
}

float decode_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
__m512 sum = _mm512_setzero_ps();
__m512 vaMagnitude = _mm512_setzero_ps();
int i = 0;
int limit = baseOffsetsLength - (baseOffsetsLength % 16);
__m512i indexRegister = initialIndexRegister;
__m512i dataBaseVec = _mm512_set1_epi32(clusterCount);

for (; i < limit; i += 16) {
__m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i));
__m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw);
// we have base offsets int, which we need to scale to index into data.
// first, we want to initialize a vector with the lane number added as an index
indexRegister = _mm512_add_epi32(indexRegister, indexIncrement);
// then we want to multiply by dataBase
__m512i scale = _mm512_mullo_epi32(indexRegister, dataBaseVec);
// then we want to add the base offsets
__m512i convOffsets = _mm512_add_epi32(scale, baseOffsetsInt);

__m512 partials = _mm512_i32gather_ps(convOffsets, data, 4);
sum = _mm512_add_ps(sum, partials);
vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitude);
}

float sumResult = _mm512_reduce_add_ps(sum);
float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude);
for (; i < baseOffsetsLength; i++) {
int offset = clusterCount * i + baseOffsets[i];
sumResult += partialSums[offset];
aMagnitudeResult += aMagnitude[offset];
}

return sumResult / sqrt(aMagnitudeResult * bMagnitude);
}

void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
int codebookBase = codebookIndex * clusterCount;
for (int i = 0; i < clusterCount; i++) {
Expand Down
1 change: 1 addition & 0 deletions jvector-native/src/main/c/jvector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
float decode_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);
Expand Down

0 comments on commit b43d85c

Please sign in to comment.