Skip to content

Commit 0b9382e

Browse files
committed
core/vector: Eliminate unwrap() calls from mod.rs
Instead of using registers_to_ref_values(), let's just pass a slice to vector functions to eliminate bunch of unwraps.
1 parent c59ae22 commit 0b9382e

File tree

2 files changed

+52
-108
lines changed

2 files changed

+52
-108
lines changed

core/vdbe/execute.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ use crate::vdbe::affinity::{apply_numeric_affinity, try_for_float, Affinity, Par
2424
use crate::vdbe::insn::InsertFlags;
2525
use crate::vdbe::value::ComparisonOp;
2626
use crate::vdbe::{registers_to_ref_values, EndStatement, StepResult, TxnCleanup};
27-
use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice};
27+
use crate::vector::{
28+
vector32, vector32_sparse, vector64, vector_concat, vector_distance_cos,
29+
vector_distance_jaccard, vector_distance_l2, vector_extract, vector_slice,
30+
};
2831
use crate::{
2932
error::{
3033
LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY,
@@ -66,7 +69,6 @@ use crate::{
6669
builder::CursorType,
6770
insn::{IdxInsertFlags, Insn},
6871
},
69-
vector::{vector32, vector64, vector_distance_cos, vector_distance_l2, vector_extract},
7072
};
7173

7274
use crate::{info, turso_assert, OpenFlags, Row, TransactionState, ValueRef};
@@ -5486,47 +5488,46 @@ pub fn op_function(
54865488
}
54875489
},
54885490
crate::function::Func::Vector(vector_func) => {
5489-
let values =
5490-
registers_to_ref_values(&state.registers[*start_reg..*start_reg + arg_count]);
5491+
let args = &state.registers[*start_reg..*start_reg + arg_count];
54915492
match vector_func {
54925493
VectorFunc::Vector => {
5493-
let result = vector32(values)?;
5494+
let result = vector32(args)?;
54945495
state.registers[*dest] = Register::Value(result);
54955496
}
54965497
VectorFunc::Vector32 => {
5497-
let result = vector32(values)?;
5498+
let result = vector32(args)?;
54985499
state.registers[*dest] = Register::Value(result);
54995500
}
55005501
VectorFunc::Vector32Sparse => {
5501-
let result = vector32_sparse(values)?;
5502+
let result = vector32_sparse(args)?;
55025503
state.registers[*dest] = Register::Value(result);
55035504
}
55045505
VectorFunc::Vector64 => {
5505-
let result = vector64(values)?;
5506+
let result = vector64(args)?;
55065507
state.registers[*dest] = Register::Value(result);
55075508
}
55085509
VectorFunc::VectorExtract => {
5509-
let result = vector_extract(values)?;
5510+
let result = vector_extract(args)?;
55105511
state.registers[*dest] = Register::Value(result);
55115512
}
55125513
VectorFunc::VectorDistanceCos => {
5513-
let result = vector_distance_cos(values)?;
5514+
let result = vector_distance_cos(args)?;
55145515
state.registers[*dest] = Register::Value(result);
55155516
}
55165517
VectorFunc::VectorDistanceL2 => {
5517-
let result = vector_distance_l2(values)?;
5518+
let result = vector_distance_l2(args)?;
55185519
state.registers[*dest] = Register::Value(result);
55195520
}
55205521
VectorFunc::VectorDistanceJaccard => {
5521-
let result = vector_distance_jaccard(values)?;
5522+
let result = vector_distance_jaccard(args)?;
55225523
state.registers[*dest] = Register::Value(result);
55235524
}
55245525
VectorFunc::VectorConcat => {
5525-
let result = vector_concat(values)?;
5526+
let result = vector_concat(args)?;
55265527
state.registers[*dest] = Register::Value(result);
55275528
}
55285529
VectorFunc::VectorSlice => {
5529-
let result = vector_slice(values)?;
5530+
let result = vector_slice(args)?;
55305531
state.registers[*dest] = Register::Value(result)
55315532
}
55325533
}

core/vector/mod.rs

Lines changed: 37 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::types::AsValueRef;
22
use crate::types::Value;
33
use crate::types::ValueType;
4+
use crate::vdbe::Register;
45
use crate::LimboError;
56
use crate::Result;
67
use crate::ValueRef;
@@ -33,75 +34,50 @@ pub fn parse_vector<'a>(
3334
}
3435
}
3536

36-
pub fn vector32<I, E, V>(args: I) -> Result<Value>
37-
where
38-
V: AsValueRef,
39-
E: ExactSizeIterator<Item = V>,
40-
I: IntoIterator<IntoIter = E, Item = V>,
41-
{
42-
let mut args = args.into_iter();
37+
pub fn vector32(args: &[Register]) -> Result<Value> {
4338
if args.len() != 1 {
4439
return Err(LimboError::ConversionError(
4540
"vector32 requires exactly one argument".to_string(),
4641
));
4742
}
48-
let value = args.next().unwrap();
49-
let vector = parse_vector(&value, Some(VectorType::Float32Dense))?;
43+
let value = args[0].get_value();
44+
let vector = parse_vector(value, Some(VectorType::Float32Dense))?;
5045
let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?;
5146
Ok(operations::serialize::vector_serialize(vector))
5247
}
5348

54-
pub fn vector32_sparse<I, E, V>(args: I) -> Result<Value>
55-
where
56-
V: AsValueRef,
57-
E: ExactSizeIterator<Item = V>,
58-
I: IntoIterator<IntoIter = E, Item = V>,
59-
{
60-
let mut args = args.into_iter();
49+
pub fn vector32_sparse(args: &[Register]) -> Result<Value> {
6150
if args.len() != 1 {
6251
return Err(LimboError::ConversionError(
6352
"vector32_sparse requires exactly one argument".to_string(),
6453
));
6554
}
66-
let value = args.next().unwrap();
67-
let vector = parse_vector(&value, Some(VectorType::Float32Sparse))?;
55+
let value = args[0].get_value();
56+
let vector = parse_vector(value, Some(VectorType::Float32Sparse))?;
6857
let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?;
6958
Ok(operations::serialize::vector_serialize(vector))
7059
}
7160

72-
pub fn vector64<I, E, V>(args: I) -> Result<Value>
73-
where
74-
V: AsValueRef,
75-
E: ExactSizeIterator<Item = V>,
76-
I: IntoIterator<IntoIter = E, Item = V>,
77-
{
78-
let mut args = args.into_iter();
61+
pub fn vector64(args: &[Register]) -> Result<Value> {
7962
if args.len() != 1 {
8063
return Err(LimboError::ConversionError(
8164
"vector64 requires exactly one argument".to_string(),
8265
));
8366
}
84-
let value = args.next().unwrap();
85-
let vector = parse_vector(&value, Some(VectorType::Float64Dense))?;
67+
let value = args[0].get_value();
68+
let vector = parse_vector(value, Some(VectorType::Float64Dense))?;
8669
let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?;
8770
Ok(operations::serialize::vector_serialize(vector))
8871
}
8972

90-
pub fn vector_extract<I, E, V>(args: I) -> Result<Value>
91-
where
92-
V: AsValueRef,
93-
E: ExactSizeIterator<Item = V>,
94-
I: IntoIterator<IntoIter = E, Item = V>,
95-
{
96-
let mut args = args.into_iter();
73+
pub fn vector_extract(args: &[Register]) -> Result<Value> {
9774
if args.len() != 1 {
9875
return Err(LimboError::ConversionError(
9976
"vector_extract requires exactly one argument".to_string(),
10077
));
10178
}
10279

103-
let value = args.next().unwrap();
104-
let value = value.as_value_ref();
80+
let value = args[0].get_value().as_value_ref();
10581
let blob = match value {
10682
ValueRef::Blob(b) => b,
10783
_ => {
@@ -119,110 +95,77 @@ where
11995
Ok(Value::build_text(operations::text::vector_to_text(&vector)))
12096
}
12197

122-
pub fn vector_distance_cos<I, E, V>(args: I) -> Result<Value>
123-
where
124-
V: AsValueRef,
125-
E: ExactSizeIterator<Item = V>,
126-
I: IntoIterator<IntoIter = E, Item = V>,
127-
{
128-
let mut args = args.into_iter();
98+
pub fn vector_distance_cos(args: &[Register]) -> Result<Value> {
12999
if args.len() != 2 {
130100
return Err(LimboError::ConversionError(
131101
"vector_distance_cos requires exactly two arguments".to_string(),
132102
));
133103
}
134104

135-
let value_0 = args.next().unwrap();
136-
let value_1 = args.next().unwrap();
137-
let x = parse_vector(&value_0, None)?;
138-
let y = parse_vector(&value_1, None)?;
105+
let value_0 = args[0].get_value();
106+
let value_1 = args[1].get_value();
107+
let x = parse_vector(value_0, None)?;
108+
let y = parse_vector(value_1, None)?;
139109
let dist = operations::distance_cos::vector_distance_cos(&x, &y)?;
140110
Ok(Value::Float(dist))
141111
}
142112

143-
pub fn vector_distance_l2<I, E, V>(args: I) -> Result<Value>
144-
where
145-
V: AsValueRef,
146-
E: ExactSizeIterator<Item = V>,
147-
I: IntoIterator<IntoIter = E, Item = V>,
148-
{
149-
let mut args = args.into_iter();
113+
pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
150114
if args.len() != 2 {
151115
return Err(LimboError::ConversionError(
152116
"distance_l2 requires exactly two arguments".to_string(),
153117
));
154118
}
155119

156-
let value_0 = args.next().unwrap();
157-
let value_1 = args.next().unwrap();
158-
let x = parse_vector(&value_0, None)?;
159-
let y = parse_vector(&value_1, None)?;
120+
let value_0 = args[0].get_value();
121+
let value_1 = args[1].get_value();
122+
let x = parse_vector(value_0, None)?;
123+
let y = parse_vector(value_1, None)?;
160124
let dist = operations::distance_l2::vector_distance_l2(&x, &y)?;
161125
Ok(Value::Float(dist))
162126
}
163127

164-
pub fn vector_distance_jaccard<I, E, V>(args: I) -> Result<Value>
165-
where
166-
V: AsValueRef,
167-
E: ExactSizeIterator<Item = V>,
168-
I: IntoIterator<IntoIter = E, Item = V>,
169-
{
170-
let mut args = args.into_iter();
128+
pub fn vector_distance_jaccard(args: &[Register]) -> Result<Value> {
171129
if args.len() != 2 {
172130
return Err(LimboError::ConversionError(
173131
"distance_jaccard requires exactly two arguments".to_string(),
174132
));
175133
}
176134

177-
let value_0 = args.next().unwrap();
178-
let value_1 = args.next().unwrap();
179-
let x = parse_vector(&value_0, None)?;
180-
let y = parse_vector(&value_1, None)?;
135+
let value_0 = args[0].get_value();
136+
let value_1 = args[1].get_value();
137+
let x = parse_vector(value_0, None)?;
138+
let y = parse_vector(value_1, None)?;
181139
let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
182140
Ok(Value::Float(dist))
183141
}
184142

185-
pub fn vector_concat<I, E, V>(args: I) -> Result<Value>
186-
where
187-
V: AsValueRef,
188-
E: ExactSizeIterator<Item = V>,
189-
I: IntoIterator<IntoIter = E, Item = V>,
190-
{
191-
let mut args = args.into_iter();
143+
pub fn vector_concat(args: &[Register]) -> Result<Value> {
192144
if args.len() != 2 {
193145
return Err(LimboError::InvalidArgument(
194146
"concat requires exactly two arguments".into(),
195147
));
196148
}
197149

198-
let value_0 = args.next().unwrap();
199-
let value_1 = args.next().unwrap();
200-
let x = parse_vector(&value_0, None)?;
201-
let y = parse_vector(&value_1, None)?;
150+
let value_0 = args[0].get_value();
151+
let value_1 = args[1].get_value();
152+
let x = parse_vector(value_0, None)?;
153+
let y = parse_vector(value_1, None)?;
202154
let vector = operations::concat::vector_concat(&x, &y)?;
203155
Ok(operations::serialize::vector_serialize(vector))
204156
}
205157

206-
pub fn vector_slice<I, E, V>(args: I) -> Result<Value>
207-
where
208-
V: AsValueRef,
209-
E: ExactSizeIterator<Item = V>,
210-
I: IntoIterator<IntoIter = E, Item = V>,
211-
{
212-
let mut args = args.into_iter();
158+
pub fn vector_slice(args: &[Register]) -> Result<Value> {
213159
if args.len() != 3 {
214160
return Err(LimboError::InvalidArgument(
215161
"vector_slice requires exactly three arguments".into(),
216162
));
217163
}
218-
let value_0 = args.next().unwrap();
219-
let value_1 = args.next().unwrap();
220-
let value_1 = value_1.as_value_ref();
221-
222-
let value_2 = args.next().unwrap();
223-
let value_2 = value_2.as_value_ref();
164+
let value_0 = args[0].get_value();
165+
let value_1 = args[1].get_value().as_value_ref();
166+
let value_2 = args[2].get_value().as_value_ref();
224167

225-
let vector = parse_vector(&value_0, None)?;
168+
let vector = parse_vector(value_0, None)?;
226169

227170
let start_index = value_1
228171
.as_int()

0 commit comments

Comments
 (0)