Skip to content

Commit

Permalink
Attempt native implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljmarshall committed Nov 14, 2024
1 parent bebb4f6 commit feb1e02
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
45 changes: 45 additions & 0 deletions jvector-native/src/main/c/jvector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,51 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c

return res;
}
float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
__m512 sum = _mm512_setzero_ps();
__m512 vaMagnitude = _mm512_setzero_ps();
int i = 0;
int limit = baseOffsetsLength - (baseOffsetsLength % 16);
__m512i indexRegister = initialIndexRegister;
__m512i scale = _mm512_set1_epi32(clusterCount);


for (; i < limit; i += 16) {
// Load and convert baseOffsets to integers
__m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i));
__m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw);

indexRegister = _mm512_add_epi32(indexRegister, indexIncrement);
// Scale the baseOffsets by the cluster count
__m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale);

// Compute the offset base by multiplying 'i' with clusterCount and broadcasting to all lanes
__m512i offsetBase = _mm512_set1_epi32(i * clusterCount);

// Calculate the final convOffsets by adding the scaled offsets and the offset base
__m512i convOffsets = _mm512_add_epi32(scaledOffsets, offsetBase);

// Gather and sum values for partial sums and a magnitude
__m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4);
sum = _mm512_add_ps(sum, partialSumVals);

__m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4);
vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals);
}

// Reduce sums
float sumResult = _mm512_reduce_add_ps(sum);
float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude);

// Handle the remaining elements
for (; i < baseOffsetsLength; i++) {
int offset = clusterCount * i + baseOffsets[i];
sumResult += partialSums[offset];
aMagnitudeResult += aMagnitude[offset];
}

return sumResult / sqrtf(aMagnitudeResult * bMagnitude);
}

void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
int codebookBase = codebookIndex * clusterCount;
Expand Down
1 change: 1 addition & 0 deletions jvector-native/src/main/c/jvector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
/* Accumulates one partialSums entry and one aMagnitude entry per baseOffsets element and returns
 * accumulatedPartialSums / sqrtf(accumulatedAMagnitude * bMagnitude). */
float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);
Expand Down

0 comments on commit feb1e02

Please sign in to comment.