Skip to content

Commit

Permalink
Break decodedCosineSimilarity out by HAS_AVX512
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljmarshall committed Nov 14, 2024
1 parent f6eca79 commit 20ed4ac
Showing 1 changed file with 45 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,12 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
}

/**
 * Entry point for the decoded cosine similarity computation; picks an implementation
 * based on the {@code HAS_AVX512} flag.
 *
 * When {@code HAS_AVX512} is true the call is forwarded to
 * {@link #decodedCosineSimilarity512}; otherwise it is forwarded to
 * {@link #decodedCosineSimilarity256}. All arguments are passed through unchanged.
 *
 * @param encoded      byte sequence whose entries select one slot per position
 * @param clusterCount stride (number of slots) between consecutive positions in the lookup tables
 * @param partialSums  lookup table accumulated into the numerator
 * @param aMagnitude   lookup table accumulated into the first magnitude term
 * @param bMagnitude   second magnitude term, used in the denominator
 * @return the value computed by the selected implementation
 */
public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
    if (HAS_AVX512) {
        return decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
    }
    return decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}

public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_512);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512);
var baseOffsets = encoded.get();
Expand Down Expand Up @@ -697,4 +703,43 @@ public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clust

return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}

/**
 * 256-bit implementation of the decoded cosine similarity.
 *
 * For every position {@code i} in the encoded bytes, the unsigned byte value {@code c}
 * selects the entry at index {@code i * clusterCount + c} in both lookup tables. The
 * selected {@code partialSums} entries are summed into the numerator and the selected
 * {@code aMagnitude} entries into the first magnitude term. The result is
 * {@code numerator / sqrt(aMagnitudeSum * bMagnitude)}.
 *
 * The bulk of the input is handled with vector gathers, one 64-bit byte vector
 * (widened to a 256-bit int vector) per iteration; whatever does not fill a whole
 * vector is handled by a scalar loop.
 *
 * @param encoded      byte sequence whose entries select one slot per position
 * @param clusterCount stride (number of slots) between consecutive positions in the lookup tables
 * @param partialSums  lookup table accumulated into the numerator
 * @param aMagnitude   lookup table accumulated into the first magnitude term
 * @param bMagnitude   second magnitude term, used in the denominator
 * @return {@code sum / sqrt(aMagnitudeSum * bMagnitude)} narrowed to float
 */
public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
    var dotAcc = FloatVector.zero(FloatVector.SPECIES_256);
    var magAcc = FloatVector.zero(FloatVector.SPECIES_256);
    var codes = encoded.get();
    var sumTable = partialSums.get();
    var magTable = aMagnitude.get();

    // Reusable int buffer for the gather index map (presumably per-thread scratch; see its declaration).
    int[] gatherIdx = scratchInt256.get();
    int pos = 0;
    final int step = ByteVector.SPECIES_64.length();
    final int vectorEnd = ByteVector.SPECIES_64.loopBound(codes.length);

    // Lane k holds k * clusterCount: the table offset of position (pos + k) relative to position pos.
    var laneStride = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount);

    while (pos < vectorEnd) {
        // Widen one vector of code bytes to ints.
        var widened = ByteVector.fromArray(ByteVector.SPECIES_64, codes, pos)
                .convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0);
        // Mask after widening; presumably this strips sign extension so codes are read as unsigned,
        // matching Byte.toUnsignedInt in the scalar tail below.
        var masked = widened.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256);
        masked.reinterpretAsInts()
                .add(laneStride)
                .intoArray(gatherIdx, 0);

        // Base offset of the first position in this batch; gatherIdx supplies the per-lane remainder.
        int base = pos * clusterCount;
        dotAcc = dotAcc.add(FloatVector.fromArray(FloatVector.SPECIES_256, sumTable, base, gatherIdx, 0));
        magAcc = magAcc.add(FloatVector.fromArray(FloatVector.SPECIES_256, magTable, base, gatherIdx, 0));

        pos += step;
    }

    float sumResult = dotAcc.reduceLanes(VectorOperators.ADD);
    float aMagnitudeResult = magAcc.reduceLanes(VectorOperators.ADD);

    // Scalar tail for the positions that did not fill a whole vector.
    while (pos < codes.length) {
        int slot = clusterCount * pos + Byte.toUnsignedInt(codes[pos]);
        sumResult += sumTable[slot];
        aMagnitudeResult += magTable[slot];
        pos++;
    }

    return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}
}

0 comments on commit 20ed4ac

Please sign in to comment.