Skip to content

Commit 20ed4ac

Browse files
Break decodedCosineSimilarity out by HAS_AVX512
1 parent f6eca79 commit 20ed4ac

File tree

1 file changed

+45
-0
lines changed
  • jvector-twenty/src/main/java/io/github/jbellis/jvector/vector

1 file changed

+45
-0
lines changed

jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -660,6 +660,12 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
660660
}
661661

662662
public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
663+
return HAS_AVX512
664+
? decodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude)
665+
: decodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
666+
}
667+
668+
public static float decodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
663669
var sum = FloatVector.zero(FloatVector.SPECIES_512);
664670
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512);
665671
var baseOffsets = encoded.get();
@@ -697,4 +703,43 @@ public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clust
697703

698704
return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
699705
}
706+
707+
public static float decodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
708+
var sum = FloatVector.zero(FloatVector.SPECIES_256);
709+
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
710+
var baseOffsets = encoded.get();
711+
var partialSumsArray = partialSums.get();
712+
var aMagnitudeArray = aMagnitude.get();
713+
714+
int[] convOffsets = scratchInt256.get();
715+
int i = 0;
716+
int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length);
717+
718+
var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount);
719+
720+
for (; i < limit; i += ByteVector.SPECIES_64.length()) {
721+
722+
ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i)
723+
.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0)
724+
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256)
725+
.reinterpretAsInts()
726+
.add(scale)
727+
.intoArray(convOffsets,0);
728+
729+
var offset = i * clusterCount;
730+
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, partialSumsArray, offset, convOffsets, 0));
731+
vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_256, aMagnitudeArray, offset, convOffsets, 0));
732+
}
733+
734+
float sumResult = sum.reduceLanes(VectorOperators.ADD);
735+
float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);
736+
737+
for (; i < baseOffsets.length; i++) {
738+
int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
739+
sumResult += partialSumsArray[offset];
740+
aMagnitudeResult += aMagnitudeArray[offset];
741+
}
742+
743+
return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
744+
}
700745
}

0 commit comments

Comments
 (0)