Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions core/vdbe/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use crate::vdbe::affinity::{apply_numeric_affinity, try_for_float, Affinity, Par
use crate::vdbe::insn::InsertFlags;
use crate::vdbe::value::ComparisonOp;
use crate::vdbe::{registers_to_ref_values, EndStatement, StepResult, TxnCleanup};
use crate::vector::{vector32_sparse, vector_concat, vector_distance_jaccard, vector_slice};
use crate::vector::{
vector32, vector32_sparse, vector64, vector_concat, vector_distance_cos,
vector_distance_jaccard, vector_distance_l2, vector_extract, vector_slice,
};
use crate::{
error::{
LimboError, SQLITE_CONSTRAINT, SQLITE_CONSTRAINT_NOTNULL, SQLITE_CONSTRAINT_PRIMARYKEY,
Expand Down Expand Up @@ -66,7 +69,6 @@ use crate::{
builder::CursorType,
insn::{IdxInsertFlags, Insn},
},
vector::{vector32, vector64, vector_distance_cos, vector_distance_l2, vector_extract},
};

use crate::{info, turso_assert, OpenFlags, Row, TransactionState, ValueRef};
Expand Down Expand Up @@ -5486,47 +5488,46 @@ pub fn op_function(
}
},
crate::function::Func::Vector(vector_func) => {
let values =
registers_to_ref_values(&state.registers[*start_reg..*start_reg + arg_count]);
let args = &state.registers[*start_reg..*start_reg + arg_count];
match vector_func {
VectorFunc::Vector => {
let result = vector32(values)?;
let result = vector32(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::Vector32 => {
let result = vector32(values)?;
let result = vector32(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::Vector32Sparse => {
let result = vector32_sparse(values)?;
let result = vector32_sparse(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::Vector64 => {
let result = vector64(values)?;
let result = vector64(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::VectorExtract => {
let result = vector_extract(values)?;
let result = vector_extract(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::VectorDistanceCos => {
let result = vector_distance_cos(values)?;
let result = vector_distance_cos(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::VectorDistanceL2 => {
let result = vector_distance_l2(values)?;
let result = vector_distance_l2(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::VectorDistanceJaccard => {
let result = vector_distance_jaccard(values)?;
let result = vector_distance_jaccard(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::VectorConcat => {
let result = vector_concat(values)?;
let result = vector_concat(args)?;
state.registers[*dest] = Register::Value(result);
}
VectorFunc::VectorSlice => {
let result = vector_slice(values)?;
let result = vector_slice(args)?;
state.registers[*dest] = Register::Value(result)
}
}
Expand Down
131 changes: 37 additions & 94 deletions core/vector/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::types::AsValueRef;
use crate::types::Value;
use crate::types::ValueType;
use crate::vdbe::Register;
use crate::LimboError;
use crate::Result;
use crate::ValueRef;
Expand Down Expand Up @@ -33,75 +34,50 @@ pub fn parse_vector<'a>(
}
}

pub fn vector32<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector32(args: &[Register]) -> Result<Value> {
if args.len() != 1 {
return Err(LimboError::ConversionError(
"vector32 requires exactly one argument".to_string(),
));
}
let value = args.next().unwrap();
let vector = parse_vector(&value, Some(VectorType::Float32Dense))?;
let value = args[0].get_value();
let vector = parse_vector(value, Some(VectorType::Float32Dense))?;
let vector = operations::convert::vector_convert(vector, VectorType::Float32Dense)?;
Ok(operations::serialize::vector_serialize(vector))
}

pub fn vector32_sparse<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector32_sparse(args: &[Register]) -> Result<Value> {
if args.len() != 1 {
return Err(LimboError::ConversionError(
"vector32_sparse requires exactly one argument".to_string(),
));
}
let value = args.next().unwrap();
let vector = parse_vector(&value, Some(VectorType::Float32Sparse))?;
let value = args[0].get_value();
let vector = parse_vector(value, Some(VectorType::Float32Sparse))?;
let vector = operations::convert::vector_convert(vector, VectorType::Float32Sparse)?;
Ok(operations::serialize::vector_serialize(vector))
}

pub fn vector64<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector64(args: &[Register]) -> Result<Value> {
if args.len() != 1 {
return Err(LimboError::ConversionError(
"vector64 requires exactly one argument".to_string(),
));
}
let value = args.next().unwrap();
let vector = parse_vector(&value, Some(VectorType::Float64Dense))?;
let value = args[0].get_value();
let vector = parse_vector(value, Some(VectorType::Float64Dense))?;
let vector = operations::convert::vector_convert(vector, VectorType::Float64Dense)?;
Ok(operations::serialize::vector_serialize(vector))
}

pub fn vector_extract<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector_extract(args: &[Register]) -> Result<Value> {
if args.len() != 1 {
return Err(LimboError::ConversionError(
"vector_extract requires exactly one argument".to_string(),
));
}

let value = args.next().unwrap();
let value = value.as_value_ref();
let value = args[0].get_value().as_value_ref();
let blob = match value {
ValueRef::Blob(b) => b,
_ => {
Expand All @@ -119,110 +95,77 @@ where
Ok(Value::build_text(operations::text::vector_to_text(&vector)))
}

pub fn vector_distance_cos<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector_distance_cos(args: &[Register]) -> Result<Value> {
if args.len() != 2 {
return Err(LimboError::ConversionError(
"vector_distance_cos requires exactly two arguments".to_string(),
));
}

let value_0 = args.next().unwrap();
let value_1 = args.next().unwrap();
let x = parse_vector(&value_0, None)?;
let y = parse_vector(&value_1, None)?;
let value_0 = args[0].get_value();
let value_1 = args[1].get_value();
let x = parse_vector(value_0, None)?;
let y = parse_vector(value_1, None)?;
let dist = operations::distance_cos::vector_distance_cos(&x, &y)?;
Ok(Value::Float(dist))
}

pub fn vector_distance_l2<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector_distance_l2(args: &[Register]) -> Result<Value> {
if args.len() != 2 {
return Err(LimboError::ConversionError(
"distance_l2 requires exactly two arguments".to_string(),
));
}

let value_0 = args.next().unwrap();
let value_1 = args.next().unwrap();
let x = parse_vector(&value_0, None)?;
let y = parse_vector(&value_1, None)?;
let value_0 = args[0].get_value();
let value_1 = args[1].get_value();
let x = parse_vector(value_0, None)?;
let y = parse_vector(value_1, None)?;
let dist = operations::distance_l2::vector_distance_l2(&x, &y)?;
Ok(Value::Float(dist))
}

pub fn vector_distance_jaccard<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector_distance_jaccard(args: &[Register]) -> Result<Value> {
if args.len() != 2 {
return Err(LimboError::ConversionError(
"distance_jaccard requires exactly two arguments".to_string(),
));
}

let value_0 = args.next().unwrap();
let value_1 = args.next().unwrap();
let x = parse_vector(&value_0, None)?;
let y = parse_vector(&value_1, None)?;
let value_0 = args[0].get_value();
let value_1 = args[1].get_value();
let x = parse_vector(value_0, None)?;
let y = parse_vector(value_1, None)?;
let dist = operations::jaccard::vector_distance_jaccard(&x, &y)?;
Ok(Value::Float(dist))
}

pub fn vector_concat<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector_concat(args: &[Register]) -> Result<Value> {
if args.len() != 2 {
return Err(LimboError::InvalidArgument(
"concat requires exactly two arguments".into(),
));
}

let value_0 = args.next().unwrap();
let value_1 = args.next().unwrap();
let x = parse_vector(&value_0, None)?;
let y = parse_vector(&value_1, None)?;
let value_0 = args[0].get_value();
let value_1 = args[1].get_value();
let x = parse_vector(value_0, None)?;
let y = parse_vector(value_1, None)?;
let vector = operations::concat::vector_concat(&x, &y)?;
Ok(operations::serialize::vector_serialize(vector))
}

pub fn vector_slice<I, E, V>(args: I) -> Result<Value>
where
V: AsValueRef,
E: ExactSizeIterator<Item = V>,
I: IntoIterator<IntoIter = E, Item = V>,
{
let mut args = args.into_iter();
pub fn vector_slice(args: &[Register]) -> Result<Value> {
if args.len() != 3 {
return Err(LimboError::InvalidArgument(
"vector_slice requires exactly three arguments".into(),
));
}
let value_0 = args.next().unwrap();
let value_1 = args.next().unwrap();
let value_1 = value_1.as_value_ref();

let value_2 = args.next().unwrap();
let value_2 = value_2.as_value_ref();
let value_0 = args[0].get_value();
let value_1 = args[1].get_value().as_value_ref();
let value_2 = args[2].get_value().as_value_ref();

let vector = parse_vector(&value_0, None)?;
let vector = parse_vector(value_0, None)?;

let start_index = value_1
.as_int()
Expand Down
25 changes: 20 additions & 5 deletions core/vector/vector_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ impl<'a> Vector<'a> {
) -> Result<Self> {
let owned_slice = owned.as_deref();
let refer_slice = refer.as_ref().map(|&x| x);
let data = owned_slice.unwrap_or_else(|| refer_slice.unwrap());
let data = owned_slice.or(refer_slice).ok_or_else(|| {
LimboError::InternalError("Vector must have either owned or refer data".to_string())
})?;
match vector_type {
VectorType::Float32Dense => {
if data.len() % 4 != 0 {
Expand Down Expand Up @@ -167,7 +169,12 @@ impl<'a> Vector<'a> {
}
let original_len = data.len();
let dims_bytes = &data[original_len - 4..];
let dims = u32::from_le_bytes(dims_bytes.try_into().unwrap()) as usize;
let dims = u32::from_le_bytes([
dims_bytes[0],
dims_bytes[1],
dims_bytes[2],
dims_bytes[3],
]) as usize;
let owned = owned.map(|mut x| {
x.truncate(original_len - 4);
x
Expand All @@ -187,17 +194,25 @@ impl<'a> Vector<'a> {
pub fn bin_len(&self) -> usize {
let owned = self.owned.as_ref().map(|x| x.len());
let refer = self.refer.as_ref().map(|x| x.len());
owned.unwrap_or_else(|| refer.unwrap())
owned
.or(refer)
.expect("Vector invariant: exactly one of owned or refer must be Some")
}

pub fn bin_data(&'a self) -> &'a [u8] {
let owned = self.owned.as_deref();
let refer = self.refer.as_ref().map(|&x| x);
owned.unwrap_or_else(|| refer.unwrap())
owned
.or(refer)
.expect("Vector invariant: exactly one of owned or refer must be Some")
}

pub fn bin_eject(self) -> Vec<u8> {
self.owned.unwrap_or_else(|| self.refer.unwrap().to_vec())
self.owned.unwrap_or_else(|| {
self.refer
.expect("Vector invariant: exactly one of owned or refer must be Some")
.to_vec()
})
}

/// # Safety
Expand Down
Loading