Skip to content

Commit 07f4a3e

Browse files
committed
feat: indexing on scalar8
Signed-off-by: usamoi <[email protected]>
1 parent f88b1ac commit 07f4a3e

File tree

11 files changed

+308
-20
lines changed

11 files changed

+308
-20
lines changed

src/sql/finalize.sql

+21
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,9 @@ CREATE OPERATOR FAMILY vector_cosine_ops USING vchordrq;
156156
CREATE OPERATOR FAMILY halfvec_l2_ops USING vchordrq;
157157
CREATE OPERATOR FAMILY halfvec_ip_ops USING vchordrq;
158158
CREATE OPERATOR FAMILY halfvec_cosine_ops USING vchordrq;
159+
CREATE OPERATOR FAMILY scalar8_l2_ops USING vchordrq;
160+
CREATE OPERATOR FAMILY scalar8_ip_ops USING vchordrq;
161+
CREATE OPERATOR FAMILY scalar8_cosine_ops USING vchordrq;
159162

160163
CREATE OPERATOR FAMILY vector_l2_ops USING Vchordrqfscan;
161164
CREATE OPERATOR FAMILY vector_ip_ops USING Vchordrqfscan;
@@ -199,6 +202,24 @@ CREATE OPERATOR CLASS halfvec_cosine_ops
199202
OPERATOR 2 <<=>> (halfvec, sphere_halfvec) FOR SEARCH,
200203
FUNCTION 1 _vchordrq_support_halfvec_cosine_ops();
201204

205+
-- scalar8_l2_ops: operator class for type scalar8 under the vchordrq access
-- method. Strategy 1 is the `<->` operator used for ORDER BY (sorted via
-- float_ops), strategy 2 is the `<<->>` operator against sphere_scalar8 used
-- for SEARCH, and support function 1 is _vchordrq_support_scalar8_l2_ops().
CREATE OPERATOR CLASS scalar8_l2_ops
206+
FOR TYPE scalar8 USING vchordrq FAMILY scalar8_l2_ops AS
207+
OPERATOR 1 <-> (scalar8, scalar8) FOR ORDER BY float_ops,
208+
OPERATOR 2 <<->> (scalar8, sphere_scalar8) FOR SEARCH,
209+
FUNCTION 1 _vchordrq_support_scalar8_l2_ops();
210+
211+
-- scalar8_ip_ops: same layout as scalar8_l2_ops, but with the `<#>` ordering
-- operator, the `<<#>>` search operator and its own support function.
CREATE OPERATOR CLASS scalar8_ip_ops
212+
FOR TYPE scalar8 USING vchordrq FAMILY scalar8_ip_ops AS
213+
OPERATOR 1 <#> (scalar8, scalar8) FOR ORDER BY float_ops,
214+
OPERATOR 2 <<#>> (scalar8, sphere_scalar8) FOR SEARCH,
215+
FUNCTION 1 _vchordrq_support_scalar8_ip_ops();
216+
217+
-- scalar8_cosine_ops: same layout again, with the `<=>` ordering operator,
-- the `<<=>>` search operator and its own support function.
CREATE OPERATOR CLASS scalar8_cosine_ops
218+
FOR TYPE scalar8 USING vchordrq FAMILY scalar8_cosine_ops AS
219+
OPERATOR 1 <=> (scalar8, scalar8) FOR ORDER BY float_ops,
220+
OPERATOR 2 <<=>> (scalar8, sphere_scalar8) FOR SEARCH,
221+
FUNCTION 1 _vchordrq_support_scalar8_cosine_ops();
222+
202223
CREATE OPERATOR CLASS vector_l2_ops
203224
FOR TYPE vector USING Vchordrqfscan FAMILY vector_l2_ops AS
204225
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,

src/vchordrq/algorithm/insert.rs

+4-4
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use crate::postgres::Relation;
2-
use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound;
2+
use crate::vchordrq::algorithm::rabitq::process_lowerbound;
33
use crate::vchordrq::algorithm::tuples::*;
44
use crate::vchordrq::algorithm::vectors;
55
use base::always_equal::AlwaysEqual;
@@ -31,7 +31,7 @@ pub fn insert<V: Vector>(
3131
let vector = vector.as_borrowed();
3232
let is_residual = meta_tuple.is_residual;
3333
let default_lut = if !is_residual {
34-
Some(V::rabitq_fscan_preprocess(vector))
34+
Some(V::rabitq_preprocess(vector))
3535
} else {
3636
None
3737
};
@@ -74,7 +74,7 @@ pub fn insert<V: Vector>(
7474
let mut results = Vec::new();
7575
{
7676
let lut = if is_residual {
77-
&V::rabitq_fscan_preprocess(
77+
&V::rabitq_preprocess(
7878
V::residual(vector, list.1.as_ref().map(|x| x.as_borrowed()).unwrap())
7979
.as_borrowed(),
8080
)
@@ -91,7 +91,7 @@ pub fn insert<V: Vector>(
9191
.map(rkyv::check_archived_root::<Height1Tuple>)
9292
.expect("data corruption")
9393
.expect("data corruption");
94-
let lowerbounds = fscan_process_lowerbound(
94+
let lowerbounds = process_lowerbound(
9595
distance_kind,
9696
dims,
9797
lut,

src/vchordrq/algorithm/rabitq.rs

+4-4
Original file line number | Diff line number | Diff line change
@@ -61,19 +61,19 @@ pub fn code(dims: u32, vector: &[f32]) -> Code {
6161

6262
pub type Lut = (f32, f32, f32, f32, (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>));
6363

64-
pub fn fscan_preprocess(vector: &[f32]) -> Lut {
64+
/// Builds the `Lut` for an f32 query vector: the sum of squares of the input,
/// the quantization offset `b` and scale `k`, the sum of the quantized codes,
/// and the binarized form of those codes.
///
/// Note the tuple order is `(dis_v_2, b, k, qvector_sum, ..)` — `b` comes
/// before `k`, which is the reverse of the `(k, b, qvector)` order returned by
/// `quantize::quantize`.
pub fn preprocess(vector: &[f32]) -> Lut {
6565
use base::simd::quantize;
6666
let dis_v_2 = f32::reduce_sum_of_x2(vector);
6767
let (k, b, qvector) = quantize::quantize(vector, 15.0);
68-
let qvector_sum = if vector.len() <= 4369 {
68+
// 4369 * 15 = 65535 = u16::MAX. Assuming `quantize(_, 15.0)` yields codes in
// 0..=15 (TODO confirm against `base::simd::quantize`), the u16 sum cannot
// overflow up to this length; longer inputs take the u32 path.
let qvector_sum = if qvector.len() <= 4369 {
6969
base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32
7070
} else {
7171
base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32
7272
};
7373
(dis_v_2, b, k, qvector_sum, binarize(&qvector))
7474
}
7575

76-
pub fn fscan_process_lowerbound(
76+
pub fn process_lowerbound(
7777
distance_kind: DistanceKind,
7878
_dims: u32,
7979
lut: &Lut,
@@ -104,7 +104,7 @@ pub fn fscan_process_lowerbound(
104104
}
105105
}
106106

107-
fn binarize(vector: &[u8]) -> (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>) {
107+
pub fn binarize(vector: &[u8]) -> (Vec<u64>, Vec<u64>, Vec<u64>, Vec<u64>) {
108108
let n = vector.len();
109109
let mut t0 = vec![0u64; n.div_ceil(64)];
110110
let mut t1 = vec![0u64; n.div_ceil(64)];

src/vchordrq/algorithm/scan.rs

+6-6
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
use crate::postgres::Relation;
2-
use crate::vchordrq::algorithm::rabitq::fscan_process_lowerbound;
2+
use crate::vchordrq::algorithm::rabitq::process_lowerbound;
33
use crate::vchordrq::algorithm::tuples::*;
44
use crate::vchordrq::algorithm::vectors;
55
use base::always_equal::AlwaysEqual;
@@ -32,7 +32,7 @@ pub fn scan<V: Vector>(
3232
let vector = V::random_projection(vector);
3333
let is_residual = meta_tuple.is_residual;
3434
let default_lut = if !is_residual {
35-
Some(V::rabitq_fscan_preprocess(vector.as_borrowed()))
35+
Some(V::rabitq_preprocess(vector.as_borrowed()))
3636
} else {
3737
None
3838
};
@@ -53,7 +53,7 @@ pub fn scan<V: Vector>(
5353
let mut results = Vec::new();
5454
for list in lists {
5555
let lut = if is_residual {
56-
&V::rabitq_fscan_preprocess(
56+
&V::rabitq_preprocess(
5757
V::residual(
5858
vector.as_borrowed(),
5959
list.1.as_ref().map(|x| x.as_borrowed()).unwrap(),
@@ -73,7 +73,7 @@ pub fn scan<V: Vector>(
7373
.map(rkyv::check_archived_root::<Height1Tuple>)
7474
.expect("data corruption")
7575
.expect("data corruption");
76-
let lowerbounds = fscan_process_lowerbound(
76+
let lowerbounds = process_lowerbound(
7777
distance_kind,
7878
dims,
7979
lut,
@@ -125,7 +125,7 @@ pub fn scan<V: Vector>(
125125
let mut results = Vec::new();
126126
for list in lists {
127127
let lut = if is_residual {
128-
&V::rabitq_fscan_preprocess(
128+
&V::rabitq_preprocess(
129129
V::residual(
130130
vector.as_borrowed(),
131131
list.1.as_ref().map(|x| x.as_borrowed()).unwrap(),
@@ -145,7 +145,7 @@ pub fn scan<V: Vector>(
145145
.map(rkyv::check_archived_root::<Height0Tuple>)
146146
.expect("data corruption")
147147
.expect("data corruption");
148-
let lowerbounds = fscan_process_lowerbound(
148+
let lowerbounds = process_lowerbound(
149149
distance_kind,
150150
dims,
151151
lut,

src/vchordrq/algorithm/tuples.rs

+148-5
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,9 @@
11
use super::rabitq::{self, Code, Lut};
2+
use crate::types::scalar8::Scalar8Owned;
23
use crate::vchordrq::types::OwnedVector;
34
use base::distance::DistanceKind;
45
use base::simd::ScalarLike;
6+
use base::vector::VectorBorrowed;
57
use base::vector::{VectOwned, VectorOwned};
68
use half::f16;
79
use rkyv::{Archive, ArchiveUnsized, CheckBytes, Deserialize, Serialize};
@@ -56,7 +58,7 @@ pub trait Vector: VectorOwned {
5658

5759
fn residual(vector: Self::Borrowed<'_>, center: Self::Borrowed<'_>) -> Self;
5860

59-
fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut;
61+
fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut;
6062

6163
fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code;
6264

@@ -129,8 +131,8 @@ impl Vector for VectOwned<f32> {
129131
Self::new(ScalarLike::vector_sub(vector.slice(), center.slice()))
130132
}
131133

132-
fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut {
133-
rabitq::fscan_preprocess(vector.slice())
134+
/// Builds the `Lut` for this vector by passing its f32 slice straight to
/// `rabitq::preprocess`; no element conversion is performed.
fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut {
135+
rabitq::preprocess(vector.slice())
134136
}
135137

136138
fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code {
@@ -212,8 +214,8 @@ impl Vector for VectOwned<f16> {
212214
Self::new(ScalarLike::vector_sub(vector.slice(), center.slice()))
213215
}
214216

215-
fn rabitq_fscan_preprocess(vector: Self::Borrowed<'_>) -> Lut {
216-
rabitq::fscan_preprocess(&f16::vector_to_f32(vector.slice()))
217+
/// Builds the `Lut` for this vector: the f16 slice is first converted with
/// `f16::vector_to_f32`, and the resulting f32 data is handed to
/// `rabitq::preprocess`.
fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut {
218+
rabitq::preprocess(&f16::vector_to_f32(vector.slice()))
217219
}
218220

219221
fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code {
@@ -229,6 +231,147 @@ impl Vector for VectOwned<f16> {
229231
}
230232
}
231233

234+
/// `Vector` implementation for `Scalar8Owned`.
///
/// A scalar8 vector is handled here as a slice of `u8` codes plus four `f32`
/// values `(sum_of_x2, k, b, sum_of_code)`. Elements are dequantized as
/// `k * code + b` (see `build_to_vecf32`), and distances are computed from the
/// integer codes and the metadata without materializing f32 vectors.
impl Vector for Scalar8Owned {
235+
// (sum_of_x2, k, b, sum_of_code), in the order produced by `vector_split`
// and consumed by `vector_merge` / `distance_end`.
type Metadata = (f32, f32, f32, f32);
236+
237+
type Element = u8;
238+
239+
/// Copies the four metadata fields out of their archived representation.
fn metadata_from_archived(
240+
archived: &<Self::Metadata as ArchiveUnsized>::Archived,
241+
) -> Self::Metadata {
242+
(archived.0, archived.1, archived.2, archived.3)
243+
}
244+
245+
/// Splits a vector into its metadata and one or more code slices.
///
/// Codes of up to 3840 bytes stay in one slice, 3841..=5120 bytes are cut
/// into two slices at offset 2560, and anything longer is cut into chunks
/// of at most 7680 bytes.
/// NOTE(review): the thresholds are presumably chosen so each slice fits
/// the index's storage unit — confirm against the tuple writer.
fn vector_split(vector: Self::Borrowed<'_>) -> (Self::Metadata, Vec<&[Self::Element]>) {
246+
let code = vector.code();
247+
(
248+
(
249+
vector.sum_of_x2(),
250+
vector.k(),
251+
vector.b(),
252+
vector.sum_of_code(),
253+
),
254+
match code.len() {
255+
0..=3840 => vec![code],
256+
3841..=5120 => vec![&code[..2560], &code[2560..]],
257+
5121.. => code.chunks(7680).collect(),
258+
},
259+
)
260+
}
261+
262+
/// Rebuilds an owned vector from metadata and a single contiguous code
/// slice; the slice is copied into a new `Vec`.
fn vector_merge(metadata: Self::Metadata, slice: &[Self::Element]) -> Self {
263+
Scalar8Owned::new(
264+
metadata.0,
265+
metadata.1,
266+
metadata.2,
267+
metadata.3,
268+
slice.to_vec(),
269+
)
270+
}
271+
272+
/// Unwraps the `Scalar8` variant of `OwnedVector`.
///
/// Any other variant hits `unreachable!()` and panics.
fn from_owned(vector: OwnedVector) -> Self {
273+
match vector {
274+
OwnedVector::Scalar8(x) => x,
275+
_ => unreachable!(),
276+
}
277+
}
278+
279+
// (distance kind, running result of `reduce_sum_of_xy` over the code slices
// seen so far, number of elements seen so far).
type DistanceAccumulator = (DistanceKind, u32, u32);
280+
281+
/// Starts an accumulation for `distance_kind` with both counters at zero.
fn distance_begin(distance_kind: DistanceKind) -> Self::DistanceAccumulator {
282+
(distance_kind, 0, 0)
283+
}
284+
285+
/// Feeds one pair of code slices into the accumulator.
///
/// L2 and Dot accumulate the same quantity (`reduce_sum_of_xy` of the two
/// slices — presumably the sum of element-wise products; confirm in
/// `base::simd::u8`); they only differ in `distance_end`. Hamming and
/// Jaccard hit `unreachable!()`.
fn distance_next(
286+
accumulator: &mut Self::DistanceAccumulator,
287+
left: &[Self::Element],
288+
right: &[Self::Element],
289+
) {
290+
match accumulator.0 {
291+
DistanceKind::L2 => accumulator.1 += base::simd::u8::reduce_sum_of_xy(left, right),
292+
DistanceKind::Dot => accumulator.1 += base::simd::u8::reduce_sum_of_xy(left, right),
293+
DistanceKind::Hamming => unreachable!(),
294+
DistanceKind::Jaccard => unreachable!(),
295+
}
296+
// Element count is taken from `left` only.
accumulator.2 += left.len() as u32;
297+
}
298+
299+
/// Turns the accumulated integer sums into an f32 distance.
///
/// With elements dequantized as `k * c + b`, the inner product of two
/// vectors u and v over n elements expands to
/// `k_u*k_v*Σ(c_u*c_v) + b_u*b_v*n + k_u*b_v*Σc_u + b_u*k_v*Σc_v`,
/// which is what `xy` computes from the accumulator and the two metadata
/// tuples. L2 returns `sum_of_x2_u + sum_of_x2_v - 2*xy`; Dot returns `-xy`.
/// Hamming and Jaccard hit `unreachable!()`.
fn distance_end(
300+
accumulator: Self::DistanceAccumulator,
301+
(sum_of_x2_u, k_u, b_u, sum_of_code_u): Self::Metadata,
302+
(sum_of_x2_v, k_v, b_v, sum_of_code_v): Self::Metadata,
303+
) -> f32 {
304+
match accumulator.0 {
305+
DistanceKind::L2 => {
306+
let xy = k_u * k_v * accumulator.1 as f32
307+
+ b_u * b_v * accumulator.2 as f32
308+
+ k_u * b_v * sum_of_code_u
309+
+ b_u * k_v * sum_of_code_v;
310+
sum_of_x2_u + sum_of_x2_v - 2.0 * xy
311+
}
312+
DistanceKind::Dot => {
313+
let xy = k_u * k_v * accumulator.1 as f32
314+
+ b_u * b_v * accumulator.2 as f32
315+
+ k_u * b_v * sum_of_code_u
316+
+ b_u * k_v * sum_of_code_v;
317+
-xy
318+
}
319+
DistanceKind::Hamming => unreachable!(),
320+
DistanceKind::Jaccard => unreachable!(),
321+
}
322+
}
323+
324+
/// Returns an owned copy of the input via `own()`; no projection is applied
/// in this implementation.
fn random_projection(vector: Self::Borrowed<'_>) -> Self {
325+
vector.own()
326+
}
327+
328+
/// Not implemented for scalar8: calling this panics via `unimplemented!()`.
fn residual(_: Self::Borrowed<'_>, _: Self::Borrowed<'_>) -> Self {
329+
unimplemented!()
330+
}
331+
332+
/// Builds the `Lut` directly from the stored u8 codes.
///
/// The tuple order is `(dis_v_2, b, k, qvector_sum, ..)` — `b` before `k`.
fn rabitq_preprocess(vector: Self::Borrowed<'_>) -> Lut {
333+
let dis_v_2 = vector.sum_of_x2();
334+
// Each code is divided by 17 below, so the scale is multiplied by 17 to
// compensate: k * x is approximated by (17 * k) * q.
let k = vector.k() * 17.0;
335+
let b = vector.b();
336+
let qvector = vector
337+
.code()
338+
.iter()
339+
// Maps a u8 code (0..=255) to 0..=15; adding 8 before the integer
// division by 17 rounds to the nearest step (255 -> 15).
.map(|&x| ((x as u32 + 8) / 17) as u8)
340+
.collect::<Vec<_>>();
341+
// 4369 * 15 = 65535 = u16::MAX, so with values in 0..=15 the u16 sum
// cannot overflow up to this length; longer inputs take the u32 path.
let qvector_sum = if qvector.len() <= 4369 {
342+
base::simd::u8::reduce_sum_of_x_as_u16(&qvector) as f32
343+
} else {
344+
base::simd::u8::reduce_sum_of_x_as_u32(&qvector) as f32
345+
};
346+
(dis_v_2, b, k, qvector_sum, rabitq::binarize(&qvector))
347+
}
348+
349+
/// Dequantizes the codes to f32 (`k * code + b`) and passes the result to
/// `rabitq::code`.
fn rabitq_code(dims: u32, vector: Self::Borrowed<'_>) -> Code {
350+
let dequantized = vector
351+
.code()
352+
.iter()
353+
.map(|&x| vector.k() * x as f32 + vector.b())
354+
.collect::<Vec<_>>();
355+
rabitq::code(dims, &dequantized)
356+
}
357+
358+
/// Dequantizes the vector into a `Vec<f32>`, one `k * code + b` per element.
fn build_to_vecf32(vector: Self::Borrowed<'_>) -> Vec<f32> {
359+
vector
360+
.code()
361+
.iter()
362+
.map(|&x| vector.k() * x as f32 + vector.b())
363+
.collect()
364+
}
365+
366+
/// Quantizes an f32 slice into a scalar8 vector.
///
/// `sum_of_x2` is computed from the original f32 input, not from the
/// dequantized codes; `sum_of_code` is summed as u32 and then cast to f32.
fn build_from_vecf32(x: &[f32]) -> Self {
367+
let sum_of_x2 = f32::reduce_sum_of_x2(x);
368+
let (k, b, code) =
369+
base::simd::quantize::quantize(f32::vector_to_f32_borrowed(x).as_ref(), 255.0);
370+
let sum_of_code = base::simd::u8::reduce_sum_of_x_as_u32(&code) as f32;
371+
Self::new(sum_of_x2, k, b, sum_of_code, code)
372+
}
373+
}
374+
232375
#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)]
233376
#[archive(check_bytes)]
234377
pub struct MetaTuple {

0 commit comments

Comments
 (0)