Skip to content

Commit 34ebde8

Browse files
authored
[ENH]: (Rust client): add true.into::<Where>() helper (#5750)
1 parent 0d3bd2a commit 34ebde8

File tree

3 files changed

+104
-76
lines changed

3 files changed

+104
-76
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

rust/types/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ proptest = { workspace = true, optional = true }
3434
proptest-derive = { workspace = true, optional = true }
3535

3636
chroma-error = { workspace = true, features = ["tonic", "validator"] }
37+
itertools.workspace = true
3738

3839
[build-dependencies]
3940
tonic-build = "0.10"

rust/types/src/metadata.rs

Lines changed: 102 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use chroma_error::{ChromaError, ErrorCodes};
2+
use itertools::Itertools;
23
use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
34
use serde_json::{Number, Value};
45
use sprs::CsVec;
@@ -995,14 +996,70 @@ impl serde::Serialize for Where {
995996
}
996997
}
997998

999+
impl From<bool> for Where {
1000+
fn from(value: bool) -> Self {
1001+
if value {
1002+
Where::conjunction(vec![])
1003+
} else {
1004+
Where::disjunction(vec![])
1005+
}
1006+
}
1007+
}
1008+
9981009
impl Where {
999-
pub fn conjunction(children: Vec<Where>) -> Self {
1010+
pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self {
1011+
// If children.len() == 0, we will return a conjunction that is always true.
1012+
// If children.len() == 1, we will return the single child.
1013+
// Otherwise, we will return a conjunction of the children.
1014+
1015+
let mut children: Vec<_> = children
1016+
.into_iter()
1017+
.flat_map(|expr| {
1018+
if let Where::Composite(CompositeExpression {
1019+
operator: BooleanOperator::And,
1020+
children,
1021+
}) = expr
1022+
{
1023+
return children;
1024+
}
1025+
vec![expr]
1026+
})
1027+
.dedup()
1028+
.collect();
1029+
1030+
if children.len() == 1 {
1031+
return children.pop().expect("just checked len is 1");
1032+
}
1033+
10001034
Self::Composite(CompositeExpression {
10011035
operator: BooleanOperator::And,
10021036
children,
10031037
})
10041038
}
1005-
pub fn disjunction(children: Vec<Where>) -> Self {
1039+
pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self {
1040+
// If children.len() == 0, we will return a disjunction that is always false.
1041+
// If children.len() == 1, we will return the single child.
1042+
// Otherwise, we will return a disjunction of the children.
1043+
1044+
let mut children: Vec<_> = children
1045+
.into_iter()
1046+
.flat_map(|expr| {
1047+
if let Where::Composite(CompositeExpression {
1048+
operator: BooleanOperator::Or,
1049+
children,
1050+
}) = expr
1051+
{
1052+
return children;
1053+
}
1054+
vec![expr]
1055+
})
1056+
.dedup()
1057+
.collect();
1058+
1059+
if children.len() == 1 {
1060+
return children.pop().expect("just checked len is 1");
1061+
}
1062+
10061063
Self::Composite(CompositeExpression {
10071064
operator: BooleanOperator::Or,
10081065
children,
@@ -1049,87 +1106,15 @@ impl BitAnd for Where {
10491106
type Output = Where;
10501107

10511108
fn bitand(self, rhs: Self) -> Self::Output {
1052-
match self {
1053-
Where::Composite(CompositeExpression {
1054-
operator: BooleanOperator::And,
1055-
mut children,
1056-
}) => match rhs {
1057-
Where::Composite(CompositeExpression {
1058-
operator: BooleanOperator::And,
1059-
children: rhs_children,
1060-
}) => {
1061-
children.extend(rhs_children);
1062-
Where::Composite(CompositeExpression {
1063-
operator: BooleanOperator::And,
1064-
children,
1065-
})
1066-
}
1067-
_ => {
1068-
children.push(rhs);
1069-
Where::Composite(CompositeExpression {
1070-
operator: BooleanOperator::And,
1071-
children,
1072-
})
1073-
}
1074-
},
1075-
_ => match rhs {
1076-
Where::Composite(CompositeExpression {
1077-
operator: BooleanOperator::And,
1078-
mut children,
1079-
}) => {
1080-
children.insert(0, self);
1081-
Where::Composite(CompositeExpression {
1082-
operator: BooleanOperator::And,
1083-
children,
1084-
})
1085-
}
1086-
_ => Where::conjunction(vec![self, rhs]),
1087-
},
1088-
}
1109+
Self::conjunction([self, rhs])
10891110
}
10901111
}
10911112

10921113
impl BitOr for Where {
10931114
type Output = Where;
10941115

10951116
fn bitor(self, rhs: Self) -> Self::Output {
1096-
match self {
1097-
Where::Composite(CompositeExpression {
1098-
operator: BooleanOperator::Or,
1099-
mut children,
1100-
}) => match rhs {
1101-
Where::Composite(CompositeExpression {
1102-
operator: BooleanOperator::Or,
1103-
children: rhs_children,
1104-
}) => {
1105-
children.extend(rhs_children);
1106-
Where::Composite(CompositeExpression {
1107-
operator: BooleanOperator::Or,
1108-
children,
1109-
})
1110-
}
1111-
_ => {
1112-
children.push(rhs);
1113-
Where::Composite(CompositeExpression {
1114-
operator: BooleanOperator::Or,
1115-
children,
1116-
})
1117-
}
1118-
},
1119-
_ => match rhs {
1120-
Where::Composite(CompositeExpression {
1121-
operator: BooleanOperator::Or,
1122-
mut children,
1123-
}) => {
1124-
children.insert(0, self);
1125-
Where::Composite(CompositeExpression {
1126-
operator: BooleanOperator::Or,
1127-
children,
1128-
})
1129-
}
1130-
_ => Where::disjunction(vec![self, rhs]),
1131-
},
1132-
}
1117+
Self::disjunction([self, rhs])
11331118
}
11341119
}
11351120

@@ -1629,6 +1614,8 @@ impl TryFrom<chroma_proto::WhereDocument> for Where {
16291614

16301615
#[cfg(test)]
16311616
mod tests {
1617+
use crate::operator::Key;
1618+
16321619
use super::*;
16331620

16341621
#[test]
@@ -2089,4 +2076,43 @@ mod tests {
20892076
assert_eq!(sv.indices, vec![0, 1]);
20902077
assert_eq!(sv.values, vec![1.0, 2.0]);
20912078
}
2079+
2080+
#[test]
2081+
fn test_simplifies_identities() {
2082+
let all: Where = true.into();
2083+
assert_eq!(all.clone() & all.clone(), true.into());
2084+
assert_eq!(all.clone() | all.clone(), true.into());
2085+
2086+
let foo = Key::field("foo").eq("bar");
2087+
assert_eq!(foo.clone() & all.clone(), foo.clone());
2088+
assert_eq!(all.clone() & foo.clone(), foo.clone());
2089+
2090+
let none: Where = false.into();
2091+
assert_eq!(foo.clone() | none.clone(), foo.clone());
2092+
assert_eq!(none | foo.clone(), foo);
2093+
}
2094+
2095+
#[test]
2096+
fn test_flattens() {
2097+
let foo = Key::field("foo").eq("bar");
2098+
let baz = Key::field("baz").eq("quux");
2099+
2100+
let and_nested = foo.clone() & (baz.clone() & foo.clone());
2101+
assert_eq!(
2102+
and_nested,
2103+
Where::Composite(CompositeExpression {
2104+
operator: BooleanOperator::And,
2105+
children: vec![foo.clone(), baz.clone(), foo.clone()]
2106+
})
2107+
);
2108+
2109+
let or_nested = foo.clone() | (baz.clone() | foo.clone());
2110+
assert_eq!(
2111+
or_nested,
2112+
Where::Composite(CompositeExpression {
2113+
operator: BooleanOperator::Or,
2114+
children: vec![foo.clone(), baz.clone(), foo.clone()]
2115+
})
2116+
);
2117+
}
20922118
}

0 commit comments

Comments
 (0)