Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 3 additions & 49 deletions crates/rabitq/src/extended.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,14 @@ pub fn ugly_code<const BITS: usize>(vector: &[f32]) -> Code {
f32::vector_mul_scalar_inplace(&mut vector, 1.0 / dis_u_2.sqrt());
vector
};
let (scale, delta) = {
let scale = {
let mut o = normalized_vector.clone();
f32::vector_abs_inplace(&mut o);
ugly_find_scale::<BITS>(&o)
};
let mut code = Vec::with_capacity(n as _);
for i in 0..n {
let v = scale * normalized_vector[i];
let v = v + v.signum() * delta[i] as f32;
let c = v.floor().clamp(min as f32, max as f32) as i32;
code.push((c + (1 << (BITS - 1))) as _);
}
Expand Down Expand Up @@ -251,55 +250,10 @@ fn find_scale<const B: usize>(o: &[f32]) -> f32 {
x_m as f32 + f32::EPSILON
}

fn ugly_find_scale<const B: usize>(o: &[f32]) -> (f32, Vec<i32>) {
fn ugly_find_scale<const B: usize>(o: &[f32]) -> f32 {
assert!((1..=8).contains(&B));

let dim = o.len();

let mut code = Vec::<u8>::with_capacity(dim);
let mut numerator_m = 0.0f64;
let mut sqr_denominator_m = 0.0f64;

let scale = (1 << (B - 1)) as f32 / f32::reduce_min_max_of_x(o).1;
for i in 0..dim {
code.push((o[i] as f64 * scale as f64) as u8);
let value = code[i] as f64 + 0.5;
numerator_m += value * o[i] as f64;
sqr_denominator_m += value * value;
}
let mut y_m = numerator_m / sqr_denominator_m.sqrt();

let mut delta = vec![0_i32; dim];
for _ in 0..8 {
for i in 0..dim {
if code[i] < (1 << (B - 1)) - 1 {
let numerator = numerator_m + o[i] as f64;
let sqr_denominator = sqr_denominator_m + 2.0 * (code[i] as f64 + 1.0);
let y = numerator / sqr_denominator.sqrt();
if y > y_m {
y_m = y;
numerator_m = numerator;
sqr_denominator_m = sqr_denominator;
code[i] += 1;
delta[i] += 1;
}
}
if code[i] > 0 {
let numerator = numerator_m - o[i] as f64;
let sqr_denominator = sqr_denominator_m - 2.0 * code[i] as f64;
let y = numerator / sqr_denominator.sqrt();
if y > y_m {
y_m = y;
numerator_m = numerator;
sqr_denominator_m = sqr_denominator;
code[i] -= 1;
delta[i] -= 1;
}
}
}
}

(scale, delta)
(1 << (B - 1)) as f32 / f32::reduce_min_max_of_x(o).1
}

pub fn pack_code<const BITS: usize>(input: &[u8]) -> [Vec<u64>; BITS] {
Expand Down
125 changes: 85 additions & 40 deletions tests/vchordg/index_vector_rabitq.slt
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq8(val)::rabitq8(64)) rabi
query I
SELECT index FROM t ORDER BY quantize_to_rabitq8(val) <-> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10;
----
155
1608
1643
155
174
818
1603
1629
60
218
1080
1568

statement ok
DROP INDEX ti;
Expand All @@ -40,16 +40,16 @@ CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq8(val)::rabitq8(64)) rabi
query I
SELECT index FROM t ORDER BY quantize_to_rabitq8(val) <#> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10;
----
155
1608
1643
155
174
818
1603
1629
60
218
1080
1639

statement ok
DROP INDEX ti;
Expand All @@ -60,16 +60,16 @@ CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq8(val)::rabitq8(64)) rabi
query I
SELECT index FROM t ORDER BY quantize_to_rabitq8(val) <=> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10;
----
155
1608
1643
155
174
818
1603
1629
60
218
1080
1568

statement ok
DROP INDEX ti;
Expand All @@ -78,17 +78,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq8(val::halfvec)::rabitq8(64)) rabitq8_l2_ops);

query I
SELECT index FROM t ORDER BY quantize_to_rabitq8(val::halfvec) <-> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 9;
SELECT index FROM t ORDER BY quantize_to_rabitq8(val::halfvec) <-> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10;
----
155
1608
1643
155
174
818
1603
818
1629
60
218
1080
79

statement ok
DROP INDEX ti;
Expand All @@ -97,17 +98,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq8(val::halfvec)::rabitq8(64)) rabitq8_ip_ops);

query I
SELECT index FROM t ORDER BY quantize_to_rabitq8(val::halfvec) <#> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 9;
SELECT index FROM t ORDER BY quantize_to_rabitq8(val::halfvec) <#> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10;
----
155
1608
1643
155
174
818
1603
1629
60
218
1080
79

statement ok
DROP INDEX ti;
Expand All @@ -116,17 +118,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq8(val::halfvec)::rabitq8(64)) rabitq8_cosine_ops);

query I
SELECT index FROM t ORDER BY quantize_to_rabitq8(val::halfvec) <=> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 9;
SELECT index FROM t ORDER BY quantize_to_rabitq8(val::halfvec) <=> quantize_to_rabitq8(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10;
----
155
1608
1643
155
174
818
1603
1629
60
218
1080
79

statement ok
DROP INDEX ti;
Expand All @@ -135,11 +138,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq4(val)::rabitq4(64)) rabitq4_l2_ops);

query I
SELECT (SUM(index) = 6162 OR SUM(index) = 6289)::int FROM (
SELECT index FROM t ORDER BY quantize_to_rabitq4(val) <-> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10
) AS s(index);
SELECT index FROM t ORDER BY quantize_to_rabitq4(val) <-> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10;
----
1
155
1643
174
818
1207
1608
1940
331
1568
1839

statement ok
DROP INDEX ti;
Expand All @@ -148,11 +158,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq4(val)::rabitq4(64)) rabitq4_ip_ops);

query I
SELECT (SUM(index) = 6162 OR SUM(index) = 6289)::int FROM (
SELECT index FROM t ORDER BY quantize_to_rabitq4(val) <#> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10
) AS s(index);
SELECT index FROM t ORDER BY quantize_to_rabitq4(val) <#> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10;
----
1
155
1643
174
818
1207
1608
1940
331
1568
1839

statement ok
DROP INDEX ti;
Expand All @@ -161,11 +178,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq4(val)::rabitq4(64)) rabitq4_cosine_ops);

query I
SELECT (SUM(index) = 6162 OR SUM(index) = 6289)::int FROM (
SELECT index FROM t ORDER BY quantize_to_rabitq4(val) <=> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10
) AS s(index);
SELECT index FROM t ORDER BY quantize_to_rabitq4(val) <=> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::vector) LIMIT 10;
----
1
155
1643
174
818
1207
1608
1940
331
1568
1839

statement ok
DROP INDEX ti;
Expand All @@ -174,11 +198,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq4(val::halfvec)::rabitq4(64)) rabitq4_l2_ops);

query I
SELECT SUM(index) FROM (
SELECT index FROM t ORDER BY quantize_to_rabitq4(val::halfvec) <-> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10
) AS s(index);
SELECT index FROM t ORDER BY quantize_to_rabitq4(val::halfvec) <-> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10;
----
7290
155
1643
174
818
1207
1608
1940
331
1839
1568

statement ok
DROP INDEX ti;
Expand All @@ -187,11 +218,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq4(val::halfvec)::rabitq4(64)) rabitq4_ip_ops);

query I
SELECT SUM(index) FROM (
SELECT index FROM t ORDER BY quantize_to_rabitq4(val::halfvec) <#> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10
) AS s(index);
SELECT index FROM t ORDER BY quantize_to_rabitq4(val::halfvec) <#> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10;
----
7290
155
1643
174
818
1207
1608
1940
331
1568
1839

statement ok
DROP INDEX ti;
Expand All @@ -200,11 +238,18 @@ statement ok
CREATE INDEX ti ON t USING vchordg ((quantize_to_rabitq4(val::halfvec)::rabitq4(64)) rabitq4_cosine_ops);

query I
SELECT SUM(index) FROM (
SELECT index FROM t ORDER BY quantize_to_rabitq4(val::halfvec) <=> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10
) AS s(index);
SELECT index FROM t ORDER BY quantize_to_rabitq4(val::halfvec) <=> quantize_to_rabitq4(array_cat(ARRAY[0.6, 0.8], ARRAY(SELECT 0.0 FROM generate_series(1, 62)))::halfvec) LIMIT 10;
----
7290
155
1643
174
818
1207
1608
1940
331
1568
1839

statement ok
DROP INDEX ti;
Expand Down
8 changes: 4 additions & 4 deletions tests/vchordrq/index_multivector_rabitq.slt
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ LIMIT 10;
1207
919
1821
1076
1639
537
174
1125
79
1163
537
194
1920
239

statement ok
DROP INDEX pg_temp.ti;
Expand Down
Loading
Loading