diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index 90ddd75117..1846f1ff81 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -24,7 +24,10 @@ use crate::vdbe::affinity::{apply_numeric_affinity, try_for_float, Affinity, Par use crate::vdbe::insn::InsertFlags; use crate::vdbe::value::ComparisonOp; use crate::vdbe::{registers_to_ref_values, EndStatement, StepResult, TxnCleanup}; -use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice}; +use crate::vector::{ + vector32, vector32_sparse, vector64, vector_concat, vector_distance_cos, + vector_distance_jaccard, vector_distance_l2, vector_extract, vector_slice, +}; use crate::{ error::{ LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY, @@ -66,7 +69,6 @@ use crate::{ builder::CursorType, insn::{IdxInsertFlags, Insn}, }, - vector::{vector32, vector64, vector_distance_cos, vector_distance_l2, vector_extract}, }; use crate::{info, turso_assert, OpenFlags, Row, TransactionState, ValueRef}; @@ -5486,47 +5488,46 @@ pub fn op_function( } }, crate::function::Func::Vector(vector_func) => { - let values = - registers_to_ref_values(&state.registers[*start_reg..*start_reg + arg_count]); + let args = &state.registers[*start_reg..*start_reg + arg_count]; match vector_func { VectorFunc::Vector => { - let result = vector32(values)?; + let result = vector32(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::Vector32 => { - let result = vector32(values)?; + let result = vector32(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::Vector32Sparse => { - let result = vector32_sparse(values)?; + let result = vector32_sparse(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::Vector64 => { - let result = vector64(values)?; + let result = vector64(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorExtract => { - let result = vector_extract(values)?; + let result = vector_extract(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorDistanceCos => { - let result = vector_distance_cos(values)?; + let result = vector_distance_cos(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorDistanceL2 => { - let result = vector_distance_l2(values)?; + let result = vector_distance_l2(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorDistanceJaccard => { - let result = vector_distance_jaccard(values)?; + let result = vector_distance_jaccard(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorConcat => { - let result = vector_concat(values)?; + let result = vector_concat(args)?; state.registers[*dest] = Register::Value(result); } VectorFunc::VectorSlice => { - let result = vector_slice(values)?; + let result = vector_slice(args)?; state.registers[*dest] = Register::Value(result) } } diff --git a/core/vector/mod.rs b/core/vector/mod.rs index 31038564b1..17f2be1316 100644 --- a/core/vector/mod.rs +++ b/core/vector/mod.rs @@ -1,6 +1,7 @@ use crate::types::AsValueRef; use crate::types::Value; use crate::types::ValueType; +use crate::vdbe::Register; use crate::LimboError; use crate::Result; use crate::ValueRef; @@ -33,75 +34,50 @@ pub fn parse_vector<'a>( } } -pub fn vector32(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector32(args: &[Register]) -> Result { if args.len() != 1 { return Err(LimboError::ConversionError( "vector32 requires exactly one argument".to_string(), )); } - let value = args.next().unwrap(); - let vector = parse_vector(&value, Some(VectorType::Float32Dense))?; + let value = args[0].get_value(); + let vector = parse_vector(value, Some(VectorType::Float32Dense))?; let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?; Ok(operations::serialize::vector_serialize(vector)) } -pub fn vector32_sparse(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector32_sparse(args: &[Register]) -> Result { if args.len() != 1 { return Err(LimboError::ConversionError( "vector32_sparse requires exactly one argument".to_string(), )); } - let value = args.next().unwrap(); - let vector = parse_vector(&value, Some(VectorType::Float32Sparse))?; + let value = args[0].get_value(); + let vector = parse_vector(value, Some(VectorType::Float32Sparse))?; let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?; Ok(operations::serialize::vector_serialize(vector)) } -pub fn vector64(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector64(args: &[Register]) -> Result { if args.len() != 1 { return Err(LimboError::ConversionError( "vector64 requires exactly one argument".to_string(), )); } - let value = args.next().unwrap(); - let vector = parse_vector(&value, Some(VectorType::Float64Dense))?; + let value = args[0].get_value(); + let vector = parse_vector(value, Some(VectorType::Float64Dense))?; let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?; Ok(operations::serialize::vector_serialize(vector)) } -pub fn vector_extract(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector_extract(args: &[Register]) -> Result { if args.len() != 1 { return Err(LimboError::ConversionError( "vector_extract requires exactly one argument".to_string(), )); } - let value = args.next().unwrap(); - let value = value.as_value_ref(); + let value = args[0].get_value().as_value_ref(); let blob = match value { ValueRef::Blob(b) => b, _ => { @@ -119,110 +95,77 @@ where Ok(Value::build_text(operations::text::vector_to_text(&vector))) } -pub fn vector_distance_cos(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector_distance_cos(args: &[Register]) -> Result { if args.len() != 2 { return Err(LimboError::ConversionError( "vector_distance_cos requires exactly two arguments".to_string(), )); } - let value_0 = args.next().unwrap(); - let value_1 = args.next().unwrap(); - let x = parse_vector(&value_0, None)?; - let y = parse_vector(&value_1, None)?; + let value_0 = args[0].get_value(); + let value_1 = args[1].get_value(); + let x = parse_vector(value_0, None)?; + let y = parse_vector(value_1, None)?; let dist = operations::distance_cos::vector_distance_cos(&x, &y)?; Ok(Value::Float(dist)) } -pub fn vector_distance_l2(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector_distance_l2(args: &[Register]) -> Result { if args.len() != 2 { return Err(LimboError::ConversionError( "distance_l2 requires exactly two arguments".to_string(), )); } - let value_0 = args.next().unwrap(); - let value_1 = args.next().unwrap(); - let x = parse_vector(&value_0, None)?; - let y = parse_vector(&value_1, None)?; + let value_0 = args[0].get_value(); + let value_1 = args[1].get_value(); + let x = parse_vector(value_0, None)?; + let y = parse_vector(value_1, None)?; let dist = operations::distance_l2::vector_distance_l2(&x, &y)?; Ok(Value::Float(dist)) } -pub fn vector_distance_jaccard(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector_distance_jaccard(args: &[Register]) -> Result { if args.len() != 2 { return Err(LimboError::ConversionError( "distance_jaccard requires exactly two arguments".to_string(), )); } - let value_0 = args.next().unwrap(); - let value_1 = args.next().unwrap(); - let x = parse_vector(&value_0, None)?; - let y = parse_vector(&value_1, None)?; + let value_0 = args[0].get_value(); + let value_1 = args[1].get_value(); + let x = parse_vector(value_0, None)?; + let y = parse_vector(value_1, None)?; let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?; Ok(Value::Float(dist)) } -pub fn vector_concat(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector_concat(args: &[Register]) -> Result { if args.len() != 2 { return Err(LimboError::InvalidArgument( "concat requires exactly two arguments".into(), )); } - let value_0 = args.next().unwrap(); - let value_1 = args.next().unwrap(); - let x = parse_vector(&value_0, None)?; - let y = parse_vector(&value_1, None)?; + let value_0 = args[0].get_value(); + let value_1 = args[1].get_value(); + let x = parse_vector(value_0, None)?; + let y = parse_vector(value_1, None)?; let vector = operations::concat::vector_concat(&x, &y)?; Ok(operations::serialize::vector_serialize(vector)) } -pub fn vector_slice(args: I) -> Result -where - V: AsValueRef, - E: ExactSizeIterator, - I: IntoIterator, -{ - let mut args = args.into_iter(); +pub fn vector_slice(args: &[Register]) -> Result { if args.len() != 3 { return Err(LimboError::InvalidArgument( "vector_slice requires exactly three arguments".into(), )); } - let value_0 = args.next().unwrap(); - let value_1 = args.next().unwrap(); - let value_1 = value_1.as_value_ref(); - - let value_2 = args.next().unwrap(); - let value_2 = value_2.as_value_ref(); + let value_0 = args[0].get_value(); + let value_1 = args[1].get_value().as_value_ref(); + let value_2 = args[2].get_value().as_value_ref(); - let vector = parse_vector(&value_0, None)?; + let vector = parse_vector(value_0, None)?; let start_index = value_1 .as_int() diff --git a/core/vector/vector_types.rs b/core/vector/vector_types.rs index 8b70fdbda2..8186119a94 100644 --- a/core/vector/vector_types.rs +++ b/core/vector/vector_types.rs @@ -128,7 +128,9 @@ impl<'a> Vector<'a> { ) -> Result { let owned_slice = owned.as_deref(); let refer_slice = refer.as_ref().map(|&x| x); - let data = owned_slice.unwrap_or_else(|| refer_slice.unwrap()); + let data = owned_slice.or(refer_slice).ok_or_else(|| { + LimboError::InternalError("Vector must have either owned or refer data".to_string()) + })?; match vector_type { VectorType::Float32Dense => { if data.len() % 4 != 0 { @@ -167,7 +169,12 @@ impl<'a> Vector<'a> { } let original_len = data.len(); let dims_bytes = &data[original_len - 4..]; - let dims = u32::from_le_bytes(dims_bytes.try_into().unwrap()) as usize; + let dims = u32::from_le_bytes([ + dims_bytes[0], + dims_bytes[1], + dims_bytes[2], + dims_bytes[3], + ]) as usize; let owned = owned.map(|mut x| { x.truncate(original_len - 4); x @@ -187,17 +194,25 @@ impl<'a> Vector<'a> { pub fn bin_len(&self) -> usize { let owned = self.owned.as_ref().map(|x| x.len()); let refer = self.refer.as_ref().map(|x| x.len()); - owned.unwrap_or_else(|| refer.unwrap()) + owned + .or(refer) + .expect("Vector invariant: exactly one of owned or refer must be Some") } pub fn bin_data(&'a self) -> &'a [u8] { let owned = self.owned.as_deref(); let refer = self.refer.as_ref().map(|&x| x); - owned.unwrap_or_else(|| refer.unwrap()) + owned + .or(refer) + .expect("Vector invariant: exactly one of owned or refer must be Some") } pub fn bin_eject(self) -> Vec { - self.owned.unwrap_or_else(|| self.refer.unwrap().to_vec()) + self.owned.unwrap_or_else(|| { + self.refer + .expect("Vector invariant: exactly one of owned or refer must be Some") + .to_vec() + }) } /// # Safety