Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reenable SimdOps.assembleAndSum; implement Panama/Native equivalent for CosineDecoder acceleration #368

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,10 @@ public float similarityTo(int node2) {
}

protected float decodedCosine(int node2) {
float sum = 0.0f;
float aMag = 0.0f;

ByteSequence<?> encoded = cv.get(node2);

for (int m = 0; m < encoded.length(); ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
sum += partialSums.get((m * cv.pq.getClusterCount()) + centroidIndex);
aMag += aMagnitude.get((m * cv.pq.getClusterCount()) + centroidIndex);
}

return (float) (sum / Math.sqrt(aMag * bMagnitude));
return VectorUtil.decodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,8 @@ public static float max(VectorFloat<?> v) {
public static float min(VectorFloat<?> v) {
return impl.min(v);
}

public static float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
return impl.decodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,18 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int
float max(VectorFloat<?> v);
float min(VectorFloat<?> v);

default float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
float sum = 0.0f;
float aMag = 0.0f;

for (int m = 0; m < encoded.length(); ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
var index = m * clusterCount + centroidIndex;
sum += partialSums.get(index);
aMag += aMagnitude.get(index);
}

return (float) (sum / Math.sqrt(aMag * bMagnitude));
}
}
43 changes: 43 additions & 0 deletions jvector-native/src/main/c/jvector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,49 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c
return res;
}

float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
__m512 sum = _mm512_setzero_ps();
__m512 vaMagnitude = _mm512_setzero_ps();
int i = 0;
int limit = baseOffsetsLength - (baseOffsetsLength % 16);
__m512i indexRegister = initialIndexRegister;
__m512i scale = _mm512_set1_epi32(clusterCount);


for (; i < limit; i += 16) {
// Load and convert baseOffsets to integers
__m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i));
__m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw);

indexRegister = _mm512_add_epi32(indexRegister, indexIncrement);
// Scale the baseOffsets by the cluster count
__m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale);

// Calculate the final convOffsets by adding the scaled indexes and the base offsets
__m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt);

// Gather and sum values for partial sums and a magnitude
__m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4);
sum = _mm512_add_ps(sum, partialSumVals);

__m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4);
vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals);
}

// Reduce sums
float sumResult = _mm512_reduce_add_ps(sum);
float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude);

// Handle the remaining elements
for (; i < baseOffsetsLength; i++) {
int offset = clusterCount * i + baseOffsets[i];
sumResult += partialSums[offset];
aMagnitudeResult += aMagnitude[offset];
}

return sumResult / sqrtf(aMagnitudeResult * bMagnitude);
}

void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
int codebookBase = codebookIndex * clusterCount;
for (int i = 0; i < clusterCount; i++) {
Expand Down
1 change: 1 addition & 0 deletions jvector-native/src/main/c/jvector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
float decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,10 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance,
((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get());
}

@Override
public float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return NativeSimdOps.decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,58 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M
}
}

private static class decoded_cosine_similarity_f32_512 {
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
NativeSimdOps.C_FLOAT,
NativeSimdOps.C_POINTER,
NativeSimdOps.C_INT,
NativeSimdOps.C_INT,
NativeSimdOps.C_POINTER,
NativeSimdOps.C_POINTER,
NativeSimdOps.C_FLOAT
);

public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(
NativeSimdOps.findOrThrow("decoded_cosine_similarity_f32_512"),
DESC, Linker.Option.critical(true));
}

/**
* Function descriptor for:
* {@snippet lang=c :
* float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static FunctionDescriptor decoded_cosine_similarity_f32_512$descriptor() {
return decoded_cosine_similarity_f32_512.DESC;
}

/**
* Downcall method handle for:
* {@snippet lang=c :
* float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static MethodHandle decoded_cosine_similarity_f32_512$handle() {
return decoded_cosine_similarity_f32_512.HANDLE;
}
/**
* {@snippet lang=c :
* float decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static float decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) {
var mh$ = decoded_cosine_similarity_f32_512.HANDLE;
try {
if (TRACE_DOWNCALLS) {
traceDowncall("decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
}
return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}

private static class calculate_partial_sums_dot_f32_512 {
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
NativeSimdOps.C_POINTER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int b

@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
float sum = 0f;
for (int i = 0; i < baseOffsets.length(); i++) {
sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i)));
}
return sum;
return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ArrayByteSequence) baseOffsets).get());
}

@Override
Expand Down Expand Up @@ -159,5 +155,11 @@ public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int
public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?> partialBases, ByteSequence<?> quantizedPartials) {
SimdOps.quantizePartials(delta, (ArrayVectorFloat) partials, (ArrayVectorFloat) partialBases, (ArrayByteSequence) quantizedPartials);
}

@Override
public float decodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return SimdOps.decodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -525,18 +525,18 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) {
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512);
int i = 0;
int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length);
var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase);

for (; i < limit; i += ByteVector.SPECIES_128.length()) {
var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(1).add(i).mul(dataBase);

ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0)
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512)
.reinterpretAsInts()
.add(scale)
.intoArray(convOffsets,0);

sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, 0, convOffsets, 0));
var offset = i * dataBase;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, offset, convOffsets, 0));
}

float res = sum.reduceLanes(VectorOperators.ADD);
Expand All @@ -553,9 +553,9 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) {
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256);
int i = 0;
int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length);
var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase);

for (; i < limit; i += ByteVector.SPECIES_64.length()) {
var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(1).add(i).mul(dataBase);

ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0)
Expand All @@ -564,7 +564,8 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) {
.add(scale)
.intoArray(convOffsets,0);

sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, 0, convOffsets, 0));
var offset = i * dataBase;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, offset, convOffsets, 0));
}

float res = sum.reduceLanes(VectorOperators.ADD);
Expand Down Expand Up @@ -658,4 +659,88 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
}
}
}

public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
return HAS_AVX512
? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude)
: decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}

public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_512);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512);
var baseOffsets = encoded.get();
var partialSumsArray = partialSums.get();
var aMagnitudeArray = aMagnitude.get();

int[] convOffsets = scratchInt512.get();
int i = 0;
int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length);

var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount);

for (; i < limit; i += ByteVector.SPECIES_128.length()) {

ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0)
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512)
.reinterpretAsInts()
.add(scale)
.intoArray(convOffsets,0);

var offset = i * clusterCount;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, partialSumsArray, offset, convOffsets, 0));
vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_512, aMagnitudeArray, offset, convOffsets, 0));
}

float sumResult = sum.reduceLanes(VectorOperators.ADD);
float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);

for (; i < baseOffsets.length; i++) {
int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
sumResult += partialSumsArray[offset];
aMagnitudeResult += aMagnitudeArray[offset];
}

return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}

public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_256);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
var baseOffsets = encoded.get();
var partialSumsArray = partialSums.get();
var aMagnitudeArray = aMagnitude.get();

int[] convOffsets = scratchInt256.get();
int i = 0;
int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length);

var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount);

for (; i < limit; i += ByteVector.SPECIES_64.length()) {

ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0)
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256)
.reinterpretAsInts()
.add(scale)
.intoArray(convOffsets,0);

var offset = i * clusterCount;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, partialSumsArray, offset, convOffsets, 0));
vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_256, aMagnitudeArray, offset, convOffsets, 0));
}

float sumResult = sum.reduceLanes(VectorOperators.ADD);
float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);

for (; i < baseOffsets.length; i++) {
int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
sumResult += partialSumsArray[offset];
aMagnitudeResult += aMagnitudeArray[offset];
}

return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}
}
Loading