diff --git a/src/lib.rs b/src/lib.rs index bab5815..ddc0cfe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -371,6 +371,78 @@ mod tests { assert_eq!(result, Ok(4)); } + #[test] + fn test_retain_unsync() { + let mut cache = unsync::Cache::::new(100); + let ranges = 0..10; + for i in ranges.clone() { + let guard = cache.get_ref_or_guard(&i).unwrap_err(); + guard.insert(i); + assert_eq!(cache.get_ref_or_guard(&i).ok().copied(), Some(i)); + } + let small = 3; + cache.retain(|&key, &val| val > small && key > small); + for i in ranges.clone() { + let actual = cache.get(&i); + if i > small { + assert!(actual.is_some()); + assert_eq!(*actual.unwrap(), i); + } else { + assert!(actual.is_none()); + } + } + let big = 7; + cache.retain(|&key, &val| val < big && key < big); + for i in ranges { + let actual = cache.get(&i); + if i > small && i < big { + assert!(actual.is_some()); + assert_eq!(*actual.unwrap(), i); + } else { + assert!(actual.is_none()); + } + } + } + + #[tokio::test] + async fn test_retain_sync() { + use crate::sync::*; + let cache = Cache::::new(100); + let ranges = 0..10; + for i in ranges.clone() { + let GuardResult::Guard(guard) = cache.get_value_or_guard(&i, None) else { + panic!(); + }; + guard.insert(i).unwrap(); + let GuardResult::Value(v) = cache.get_value_or_guard(&i, None) else { + panic!(); + }; + assert_eq!(v, i); + } + let small = 4; + cache.retain(|&key, &val| val > small && key > small); + for i in ranges.clone() { + let actual = cache.get(&i); + if i > small { + assert!(actual.is_some()); + assert_eq!(actual.unwrap(), i); + } else { + assert!(actual.is_none()); + } + } + let big = 8; + cache.retain(|&key, &val| val < big && key < big); + for i in ranges { + let actual = cache.get(&i); + if i > small && i < big { + assert!(actual.is_some()); + assert_eq!(actual.unwrap(), i); + } else { + assert!(actual.is_none()); + } + } + } + #[test] #[cfg_attr(miri, ignore)] fn test_value_or_guard() { diff --git a/src/shard.rs b/src/shard.rs index f780696..65fc148 100644 --- a/src/shard.rs +++ b/src/shard.rs @@ -332,6 +332,33 @@ impl< }) } + pub fn retain(&mut self, f: F) + where + F: Fn(&Key, &Val) -> bool, + { + let retained_tokens = self + .map + .iter() + .filter_map(|&idx| match self.entries.get(idx) { + Some((entry, _idx)) => match entry { + Entry::Resident(r) => { + if !f(&r.key, &r.value) { + let hash = self.hash(&r.key); + Some((idx, hash)) + } else { + None + } + } + Entry::Placeholder(_) | Entry::Ghost(_) => None, + }, + None => None, + }) + .collect::>(); + for (idx, hash) in retained_tokens { + self.remove_internal(hash, idx); + } + } + pub fn weight(&self) -> u64 { self.weight_hot + self.weight_cold } diff --git a/src/sync.rs b/src/sync.rs index f8b50c0..1443446 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -287,6 +287,18 @@ impl< Ok(lcs) } + /// Retains only the items specified by the predicate. + /// In other words, remove all items for which `f(&key, &value)` returns `false`. The + /// elements are visited in arbitrary order. + pub fn retain(&self, f: F) + where + F: Fn(&Key, &Val) -> bool, + { + for s in self.shards.iter() { + s.write().retain(&f); + } + } + /// Inserts an item in the cache with key `key`. pub fn insert(&self, key: Key, value: Val) { let lcs = self.insert_with_lifecycle(key, value); diff --git a/src/unsync.rs b/src/unsync.rs index f84c0a1..0aded9c 100644 --- a/src/unsync.rs +++ b/src/unsync.rs @@ -227,6 +227,16 @@ impl, B: BuildHasher, L: Lifecycle(&mut self, f: F) + where + F: Fn(&Key, &Val) -> bool, + { + self.shard.retain(f); + } + /// Gets or inserts an item in the cache with key `key`. /// Returns a reference to the inserted `value` if it was admitted to the cache. ///