1
1
use super :: rabitq:: { self , Code , Lut } ;
2
+ use crate :: types:: scalar8:: Scalar8Owned ;
2
3
use crate :: vchordrq:: types:: OwnedVector ;
3
4
use base:: distance:: DistanceKind ;
4
5
use base:: simd:: ScalarLike ;
6
+ use base:: vector:: VectorBorrowed ;
5
7
use base:: vector:: { VectOwned , VectorOwned } ;
6
8
use half:: f16;
7
9
use rkyv:: { Archive , ArchiveUnsized , CheckBytes , Deserialize , Serialize } ;
@@ -56,7 +58,7 @@ pub trait Vector: VectorOwned {
56
58
57
59
fn residual ( vector : Self :: Borrowed < ' _ > , center : Self :: Borrowed < ' _ > ) -> Self ;
58
60
59
- fn rabitq_fscan_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut ;
61
+ fn rabitq_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut ;
60
62
61
63
fn rabitq_code ( dims : u32 , vector : Self :: Borrowed < ' _ > ) -> Code ;
62
64
@@ -129,8 +131,8 @@ impl Vector for VectOwned<f32> {
129
131
Self :: new ( ScalarLike :: vector_sub ( vector. slice ( ) , center. slice ( ) ) )
130
132
}
131
133
132
- fn rabitq_fscan_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut {
133
- rabitq:: fscan_preprocess ( vector. slice ( ) )
134
+ fn rabitq_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut {
135
+ rabitq:: preprocess ( vector. slice ( ) )
134
136
}
135
137
136
138
fn rabitq_code ( dims : u32 , vector : Self :: Borrowed < ' _ > ) -> Code {
@@ -212,8 +214,8 @@ impl Vector for VectOwned<f16> {
212
214
Self :: new ( ScalarLike :: vector_sub ( vector. slice ( ) , center. slice ( ) ) )
213
215
}
214
216
215
- fn rabitq_fscan_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut {
216
- rabitq:: fscan_preprocess ( & f16:: vector_to_f32 ( vector. slice ( ) ) )
217
+ fn rabitq_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut {
218
+ rabitq:: preprocess ( & f16:: vector_to_f32 ( vector. slice ( ) ) )
217
219
}
218
220
219
221
fn rabitq_code ( dims : u32 , vector : Self :: Borrowed < ' _ > ) -> Code {
@@ -229,6 +231,147 @@ impl Vector for VectOwned<f16> {
229
231
}
230
232
}
231
233
234
+ impl Vector for Scalar8Owned {
235
+ type Metadata = ( f32 , f32 , f32 , f32 ) ;
236
+
237
+ type Element = u8 ;
238
+
239
+ fn metadata_from_archived (
240
+ archived : & <Self :: Metadata as ArchiveUnsized >:: Archived ,
241
+ ) -> Self :: Metadata {
242
+ ( archived. 0 , archived. 1 , archived. 2 , archived. 3 )
243
+ }
244
+
245
+ fn vector_split ( vector : Self :: Borrowed < ' _ > ) -> ( Self :: Metadata , Vec < & [ Self :: Element ] > ) {
246
+ let code = vector. code ( ) ;
247
+ (
248
+ (
249
+ vector. sum_of_x2 ( ) ,
250
+ vector. k ( ) ,
251
+ vector. b ( ) ,
252
+ vector. sum_of_code ( ) ,
253
+ ) ,
254
+ match code. len ( ) {
255
+ 0 ..=3840 => vec ! [ code] ,
256
+ 3841 ..=5120 => vec ! [ & code[ ..2560 ] , & code[ 2560 ..] ] ,
257
+ 5121 .. => code. chunks ( 7680 ) . collect ( ) ,
258
+ } ,
259
+ )
260
+ }
261
+
262
+ fn vector_merge ( metadata : Self :: Metadata , slice : & [ Self :: Element ] ) -> Self {
263
+ Scalar8Owned :: new (
264
+ metadata. 0 ,
265
+ metadata. 1 ,
266
+ metadata. 2 ,
267
+ metadata. 3 ,
268
+ slice. to_vec ( ) ,
269
+ )
270
+ }
271
+
272
+ fn from_owned ( vector : OwnedVector ) -> Self {
273
+ match vector {
274
+ OwnedVector :: Scalar8 ( x) => x,
275
+ _ => unreachable ! ( ) ,
276
+ }
277
+ }
278
+
279
+ type DistanceAccumulator = ( DistanceKind , u32 , u32 ) ;
280
+
281
+ fn distance_begin ( distance_kind : DistanceKind ) -> Self :: DistanceAccumulator {
282
+ ( distance_kind, 0 , 0 )
283
+ }
284
+
285
+ fn distance_next (
286
+ accumulator : & mut Self :: DistanceAccumulator ,
287
+ left : & [ Self :: Element ] ,
288
+ right : & [ Self :: Element ] ,
289
+ ) {
290
+ match accumulator. 0 {
291
+ DistanceKind :: L2 => accumulator. 1 += base:: simd:: u8:: reduce_sum_of_xy ( left, right) ,
292
+ DistanceKind :: Dot => accumulator. 1 += base:: simd:: u8:: reduce_sum_of_xy ( left, right) ,
293
+ DistanceKind :: Hamming => unreachable ! ( ) ,
294
+ DistanceKind :: Jaccard => unreachable ! ( ) ,
295
+ }
296
+ accumulator. 2 += left. len ( ) as u32 ;
297
+ }
298
+
299
+ fn distance_end (
300
+ accumulator : Self :: DistanceAccumulator ,
301
+ ( sum_of_x2_u, k_u, b_u, sum_of_code_u) : Self :: Metadata ,
302
+ ( sum_of_x2_v, k_v, b_v, sum_of_code_v) : Self :: Metadata ,
303
+ ) -> f32 {
304
+ match accumulator. 0 {
305
+ DistanceKind :: L2 => {
306
+ let xy = k_u * k_v * accumulator. 1 as f32
307
+ + b_u * b_v * accumulator. 2 as f32
308
+ + k_u * b_v * sum_of_code_u
309
+ + b_u * k_v * sum_of_code_v;
310
+ sum_of_x2_u + sum_of_x2_v - 2.0 * xy
311
+ }
312
+ DistanceKind :: Dot => {
313
+ let xy = k_u * k_v * accumulator. 1 as f32
314
+ + b_u * b_v * accumulator. 2 as f32
315
+ + k_u * b_v * sum_of_code_u
316
+ + b_u * k_v * sum_of_code_v;
317
+ -xy
318
+ }
319
+ DistanceKind :: Hamming => unreachable ! ( ) ,
320
+ DistanceKind :: Jaccard => unreachable ! ( ) ,
321
+ }
322
+ }
323
+
324
+ fn random_projection ( vector : Self :: Borrowed < ' _ > ) -> Self {
325
+ vector. own ( )
326
+ }
327
+
328
+ fn residual ( _: Self :: Borrowed < ' _ > , _: Self :: Borrowed < ' _ > ) -> Self {
329
+ unimplemented ! ( )
330
+ }
331
+
332
+ fn rabitq_preprocess ( vector : Self :: Borrowed < ' _ > ) -> Lut {
333
+ let dis_v_2 = vector. sum_of_x2 ( ) ;
334
+ let k = vector. k ( ) * 17.0 ;
335
+ let b = vector. b ( ) ;
336
+ let qvector = vector
337
+ . code ( )
338
+ . iter ( )
339
+ . map ( |& x| ( ( x as u32 + 8 ) / 17 ) as u8 )
340
+ . collect :: < Vec < _ > > ( ) ;
341
+ let qvector_sum = if qvector. len ( ) <= 4369 {
342
+ base:: simd:: u8:: reduce_sum_of_x_as_u16 ( & qvector) as f32
343
+ } else {
344
+ base:: simd:: u8:: reduce_sum_of_x_as_u32 ( & qvector) as f32
345
+ } ;
346
+ ( dis_v_2, b, k, qvector_sum, rabitq:: binarize ( & qvector) )
347
+ }
348
+
349
+ fn rabitq_code ( dims : u32 , vector : Self :: Borrowed < ' _ > ) -> Code {
350
+ let dequantized = vector
351
+ . code ( )
352
+ . iter ( )
353
+ . map ( |& x| vector. k ( ) * x as f32 + vector. b ( ) )
354
+ . collect :: < Vec < _ > > ( ) ;
355
+ rabitq:: code ( dims, & dequantized)
356
+ }
357
+
358
+ fn build_to_vecf32 ( vector : Self :: Borrowed < ' _ > ) -> Vec < f32 > {
359
+ vector
360
+ . code ( )
361
+ . iter ( )
362
+ . map ( |& x| vector. k ( ) * x as f32 + vector. b ( ) )
363
+ . collect ( )
364
+ }
365
+
366
+ fn build_from_vecf32 ( x : & [ f32 ] ) -> Self {
367
+ let sum_of_x2 = f32:: reduce_sum_of_x2 ( x) ;
368
+ let ( k, b, code) =
369
+ base:: simd:: quantize:: quantize ( f32:: vector_to_f32_borrowed ( x) . as_ref ( ) , 255.0 ) ;
370
+ let sum_of_code = base:: simd:: u8:: reduce_sum_of_x_as_u32 ( & code) as f32 ;
371
+ Self :: new ( sum_of_x2, k, b, sum_of_code, code)
372
+ }
373
+ }
374
+
232
375
#[ derive( Clone , PartialEq , Archive , Serialize , Deserialize ) ]
233
376
#[ archive( check_bytes) ]
234
377
pub struct MetaTuple {
0 commit comments