|
1 | 1 | use chroma_error::{ChromaError, ErrorCodes}; |
| 2 | +use itertools::Itertools; |
2 | 3 | use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; |
3 | 4 | use serde_json::{Number, Value}; |
4 | 5 | use sprs::CsVec; |
@@ -995,14 +996,70 @@ impl serde::Serialize for Where { |
995 | 996 | } |
996 | 997 | } |
997 | 998 |
|
| 999 | +impl From<bool> for Where { |
| 1000 | + fn from(value: bool) -> Self { |
| 1001 | + if value { |
| 1002 | + Where::conjunction(vec![]) |
| 1003 | + } else { |
| 1004 | + Where::disjunction(vec![]) |
| 1005 | + } |
| 1006 | + } |
| 1007 | +} |
| 1008 | + |
998 | 1009 | impl Where { |
999 | | - pub fn conjunction(children: Vec<Where>) -> Self { |
| 1010 | + pub fn conjunction(children: impl IntoIterator<Item = Where>) -> Self { |
| 1011 | + // Nested `And` composites are flattened into one list and consecutive duplicate children are dropped. |
| 1012 | + // If exactly one child remains after that, the child itself is returned. |
| 1013 | + // Otherwise (including zero children, which is always true), we return a conjunction of the children. |
| 1014 | + |
| 1015 | + let mut children: Vec<_> = children |
| 1016 | + .into_iter() |
| 1017 | + .flat_map(|expr| { |
| 1018 | + if let Where::Composite(CompositeExpression { |
| 1019 | + operator: BooleanOperator::And, |
| 1020 | + children, |
| 1021 | + }) = expr |
| 1022 | + { |
| 1023 | + return children; |
| 1024 | + } |
| 1025 | + vec![expr] |
| 1026 | + }) |
| 1027 | + .dedup() |
| 1028 | + .collect(); |
| 1029 | + |
| 1030 | + if children.len() == 1 { |
| 1031 | + return children.pop().expect("just checked len is 1"); |
| 1032 | + } |
| 1033 | + |
1000 | 1034 | Self::Composite(CompositeExpression { |
1001 | 1035 | operator: BooleanOperator::And, |
1002 | 1036 | children, |
1003 | 1037 | }) |
1004 | 1038 | } |
1005 | | - pub fn disjunction(children: Vec<Where>) -> Self { |
| 1039 | + pub fn disjunction(children: impl IntoIterator<Item = Where>) -> Self { |
| 1040 | + // Nested `Or` composites are flattened into one list and consecutive duplicate children are dropped. |
| 1041 | + // If exactly one child remains after that, the child itself is returned. |
| 1042 | + // Otherwise (including zero children, which is always false), we return a disjunction of the children. |
| 1043 | + |
| 1044 | + let mut children: Vec<_> = children |
| 1045 | + .into_iter() |
| 1046 | + .flat_map(|expr| { |
| 1047 | + if let Where::Composite(CompositeExpression { |
| 1048 | + operator: BooleanOperator::Or, |
| 1049 | + children, |
| 1050 | + }) = expr |
| 1051 | + { |
| 1052 | + return children; |
| 1053 | + } |
| 1054 | + vec![expr] |
| 1055 | + }) |
| 1056 | + .dedup() |
| 1057 | + .collect(); |
| 1058 | + |
| 1059 | + if children.len() == 1 { |
| 1060 | + return children.pop().expect("just checked len is 1"); |
| 1061 | + } |
| 1062 | + |
1006 | 1063 | Self::Composite(CompositeExpression { |
1007 | 1064 | operator: BooleanOperator::Or, |
1008 | 1065 | children, |
@@ -1049,87 +1106,15 @@ impl BitAnd for Where { |
1049 | 1106 | type Output = Where; |
1050 | 1107 |
|
1051 | 1108 | fn bitand(self, rhs: Self) -> Self::Output { |
1052 | | - match self { |
1053 | | - Where::Composite(CompositeExpression { |
1054 | | - operator: BooleanOperator::And, |
1055 | | - mut children, |
1056 | | - }) => match rhs { |
1057 | | - Where::Composite(CompositeExpression { |
1058 | | - operator: BooleanOperator::And, |
1059 | | - children: rhs_children, |
1060 | | - }) => { |
1061 | | - children.extend(rhs_children); |
1062 | | - Where::Composite(CompositeExpression { |
1063 | | - operator: BooleanOperator::And, |
1064 | | - children, |
1065 | | - }) |
1066 | | - } |
1067 | | - _ => { |
1068 | | - children.push(rhs); |
1069 | | - Where::Composite(CompositeExpression { |
1070 | | - operator: BooleanOperator::And, |
1071 | | - children, |
1072 | | - }) |
1073 | | - } |
1074 | | - }, |
1075 | | - _ => match rhs { |
1076 | | - Where::Composite(CompositeExpression { |
1077 | | - operator: BooleanOperator::And, |
1078 | | - mut children, |
1079 | | - }) => { |
1080 | | - children.insert(0, self); |
1081 | | - Where::Composite(CompositeExpression { |
1082 | | - operator: BooleanOperator::And, |
1083 | | - children, |
1084 | | - }) |
1085 | | - } |
1086 | | - _ => Where::conjunction(vec![self, rhs]), |
1087 | | - }, |
1088 | | - } |
| 1109 | + Self::conjunction([self, rhs]) |
1089 | 1110 | } |
1090 | 1111 | } |
1091 | 1112 |
|
1092 | 1113 | impl BitOr for Where { |
1093 | 1114 | type Output = Where; |
1094 | 1115 |
|
1095 | 1116 | fn bitor(self, rhs: Self) -> Self::Output { |
1096 | | - match self { |
1097 | | - Where::Composite(CompositeExpression { |
1098 | | - operator: BooleanOperator::Or, |
1099 | | - mut children, |
1100 | | - }) => match rhs { |
1101 | | - Where::Composite(CompositeExpression { |
1102 | | - operator: BooleanOperator::Or, |
1103 | | - children: rhs_children, |
1104 | | - }) => { |
1105 | | - children.extend(rhs_children); |
1106 | | - Where::Composite(CompositeExpression { |
1107 | | - operator: BooleanOperator::Or, |
1108 | | - children, |
1109 | | - }) |
1110 | | - } |
1111 | | - _ => { |
1112 | | - children.push(rhs); |
1113 | | - Where::Composite(CompositeExpression { |
1114 | | - operator: BooleanOperator::Or, |
1115 | | - children, |
1116 | | - }) |
1117 | | - } |
1118 | | - }, |
1119 | | - _ => match rhs { |
1120 | | - Where::Composite(CompositeExpression { |
1121 | | - operator: BooleanOperator::Or, |
1122 | | - mut children, |
1123 | | - }) => { |
1124 | | - children.insert(0, self); |
1125 | | - Where::Composite(CompositeExpression { |
1126 | | - operator: BooleanOperator::Or, |
1127 | | - children, |
1128 | | - }) |
1129 | | - } |
1130 | | - _ => Where::disjunction(vec![self, rhs]), |
1131 | | - }, |
1132 | | - } |
| 1117 | + Self::disjunction([self, rhs]) |
1133 | 1118 | } |
1134 | 1119 | } |
1135 | 1120 |
|
@@ -1629,6 +1614,8 @@ impl TryFrom<chroma_proto::WhereDocument> for Where { |
1629 | 1614 |
|
1630 | 1615 | #[cfg(test)] |
1631 | 1616 | mod tests { |
| 1617 | + use crate::operator::Key; |
| 1618 | + |
1632 | 1619 | use super::*; |
1633 | 1620 |
|
1634 | 1621 | #[test] |
@@ -2089,4 +2076,43 @@ mod tests { |
2089 | 2076 | assert_eq!(sv.indices, vec![0, 1]); |
2090 | 2077 | assert_eq!(sv.values, vec![1.0, 2.0]); |
2091 | 2078 | } |
| 2079 | + |
| 2080 | + #[test] |
| 2081 | + fn test_simplifies_identities() { |
| 2082 | + let all: Where = true.into(); |
| 2083 | + assert_eq!(all.clone() & all.clone(), true.into()); |
| 2084 | + assert_eq!(all.clone() | all.clone(), true.into()); |
| 2085 | + |
| 2086 | + let foo = Key::field("foo").eq("bar"); |
| 2087 | + assert_eq!(foo.clone() & all.clone(), foo.clone()); |
| 2088 | + assert_eq!(all.clone() & foo.clone(), foo.clone()); |
| 2089 | + |
| 2090 | + let none: Where = false.into(); |
| 2091 | + assert_eq!(foo.clone() | none.clone(), foo.clone()); |
| 2092 | + assert_eq!(none | foo.clone(), foo); |
| 2093 | + } |
| 2094 | + |
| 2095 | + #[test] |
| 2096 | + fn test_flattens() { |
| 2097 | + let foo = Key::field("foo").eq("bar"); |
| 2098 | + let baz = Key::field("baz").eq("quux"); |
| 2099 | + |
| 2100 | + let and_nested = foo.clone() & (baz.clone() & foo.clone()); |
| 2101 | + assert_eq!( |
| 2102 | + and_nested, |
| 2103 | + Where::Composite(CompositeExpression { |
| 2104 | + operator: BooleanOperator::And, |
| 2105 | + children: vec![foo.clone(), baz.clone(), foo.clone()] |
| 2106 | + }) |
| 2107 | + ); |
| 2108 | + |
| 2109 | + let or_nested = foo.clone() | (baz.clone() | foo.clone()); |
| 2110 | + assert_eq!( |
| 2111 | + or_nested, |
| 2112 | + Where::Composite(CompositeExpression { |
| 2113 | + operator: BooleanOperator::Or, |
| 2114 | + children: vec![foo.clone(), baz.clone(), foo.clone()] |
| 2115 | + }) |
| 2116 | + ); |
| 2117 | + } |
2092 | 2118 | } |
0 commit comments