@@ -660,6 +660,12 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
660
660
}
661
661
662
662
public static float decodedCosineSimilarity (ArrayByteSequence encoded , int clusterCount , ArrayVectorFloat partialSums , ArrayVectorFloat aMagnitude , float bMagnitude ) {
663
+ return HAS_AVX512
664
+ ? decodedCosineSimilarity512 (encoded , clusterCount , partialSums , aMagnitude , bMagnitude )
665
+ : decodedCosineSimilarity256 (encoded , clusterCount , partialSums , aMagnitude , bMagnitude );
666
+ }
667
+
668
+ public static float decodedCosineSimilarity512 (ArrayByteSequence encoded , int clusterCount , ArrayVectorFloat partialSums , ArrayVectorFloat aMagnitude , float bMagnitude ) {
663
669
var sum = FloatVector .zero (FloatVector .SPECIES_512 );
664
670
var vaMagnitude = FloatVector .zero (FloatVector .SPECIES_512 );
665
671
var baseOffsets = encoded .get ();
@@ -697,4 +703,43 @@ public static float decodedCosineSimilarity(ArrayByteSequence encoded, int clust
697
703
698
704
return (float ) (sumResult / Math .sqrt (aMagnitudeResult * bMagnitude ));
699
705
}
706
+
707
+ public static float decodedCosineSimilarity256 (ArrayByteSequence encoded , int clusterCount , ArrayVectorFloat partialSums , ArrayVectorFloat aMagnitude , float bMagnitude ) {
708
+ var sum = FloatVector .zero (FloatVector .SPECIES_256 );
709
+ var vaMagnitude = FloatVector .zero (FloatVector .SPECIES_256 );
710
+ var baseOffsets = encoded .get ();
711
+ var partialSumsArray = partialSums .get ();
712
+ var aMagnitudeArray = aMagnitude .get ();
713
+
714
+ int [] convOffsets = scratchInt256 .get ();
715
+ int i = 0 ;
716
+ int limit = ByteVector .SPECIES_64 .loopBound (baseOffsets .length );
717
+
718
+ var scale = IntVector .zero (IntVector .SPECIES_256 ).addIndex (clusterCount );
719
+
720
+ for (; i < limit ; i += ByteVector .SPECIES_64 .length ()) {
721
+
722
+ ByteVector .fromArray (ByteVector .SPECIES_64 , baseOffsets , i )
723
+ .convertShape (VectorOperators .B2I , IntVector .SPECIES_256 , 0 )
724
+ .lanewise (VectorOperators .AND , BYTE_TO_INT_MASK_256 )
725
+ .reinterpretAsInts ()
726
+ .add (scale )
727
+ .intoArray (convOffsets ,0 );
728
+
729
+ var offset = i * clusterCount ;
730
+ sum = sum .add (FloatVector .fromArray (FloatVector .SPECIES_256 , partialSumsArray , offset , convOffsets , 0 ));
731
+ vaMagnitude = vaMagnitude .add (FloatVector .fromArray (FloatVector .SPECIES_256 , aMagnitudeArray , offset , convOffsets , 0 ));
732
+ }
733
+
734
+ float sumResult = sum .reduceLanes (VectorOperators .ADD );
735
+ float aMagnitudeResult = vaMagnitude .reduceLanes (VectorOperators .ADD );
736
+
737
+ for (; i < baseOffsets .length ; i ++) {
738
+ int offset = clusterCount * i + Byte .toUnsignedInt (baseOffsets [i ]);
739
+ sumResult += partialSumsArray [offset ];
740
+ aMagnitudeResult += aMagnitudeArray [offset ];
741
+ }
742
+
743
+ return (float ) (sumResult / Math .sqrt (aMagnitudeResult * bMagnitude ));
744
+ }
700
745
}
0 commit comments