Skip to content

Commit 59810e4

Browse files
committed
refactor ClusterMap to never hand out dashmap references
1 parent 4ba1d36 commit 59810e4

File tree

6 files changed

+115
-82
lines changed

6 files changed

+115
-82
lines changed

benches/cluster_map.rs

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,19 @@ mod serde {
1717
fn serialize_to_protobuf(cm: &ClusterMap) -> Vec<Any> {
1818
let mut resources = Vec::new();
1919

20-
for cluster in cm.iter() {
21-
resources.push(
22-
Resource::Cluster(Cluster {
23-
locality: cluster.key().clone().map(From::from),
24-
endpoints: cluster
25-
.endpoints
26-
.iter()
27-
.map(TryFrom::try_from)
28-
.collect::<Result<_, _>>()
29-
.unwrap(),
30-
})
31-
.try_encode()
32-
.unwrap(),
33-
);
34-
}
20+
cm.iter_with(|locality, endpoint_set| {
21+
Resource::Cluster(Cluster {
22+
locality: locality.clone().map(From::from),
23+
endpoints: endpoint_set
24+
.endpoints
25+
.iter()
26+
.map(TryFrom::try_from)
27+
.collect::<Result<_, _>>()
28+
.unwrap(),
29+
})
30+
})
31+
.into_iter()
32+
.map(|resource| resources.push(resource.try_encode().unwrap()));
3533

3634
resources
3735
}
@@ -111,11 +109,11 @@ mod ops {
111109
use shared::{GenCluster, gen_cluster_map};
112110

113111
fn compute_hash<const S: u64>(gc: &GenCluster) -> usize {
114-
let mut total_endpoints = 0;
115-
116-
for kv in gc.cm.iter() {
117-
total_endpoints += kv.endpoints.len();
118-
}
112+
let total_endpoints = gc
113+
.cm
114+
.iter_with(|_locality, endpoint_set| endpoint_set.len())
115+
.iter()
116+
.fold(0, |acc, e| acc + e);
119117

120118
assert_eq!(total_endpoints, gc.total_endpoints);
121119
total_endpoints

benches/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ pub fn gen_cluster_map<const S: u64>(token_kind: TokenKind) -> GenCluster {
690690

691691
// Now actually insert the endpoints, now that the order of keys is established,
692692
// annoying, but note we split out iteration versus insertion, otherwise we deadlock
693-
let keys: Vec<_> = cm.iter().map(|kv| kv.key().clone()).collect();
693+
let keys = cm.iter_with(|locality, _endpoint_set| locality.clone());
694694
let mut sets = std::collections::BTreeMap::new();
695695

696696
let mut token_generator = match token_kind {

benches/token_router.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,18 @@ fn token_router(b: Bencher<'_, '_>, token_kind: &str) {
1414
let cm = std::sync::Arc::new(gc.cm);
1515

1616
// Calculate the amount of bytes for all the tokens
17-
for eps in cm.iter() {
18-
for ep in &eps.value().endpoints {
17+
cm.iter_with(|_locality, endpoint_set| {
18+
let mut tokens = Vec::new();
19+
for ep in &endpoint_set.endpoints {
1920
for tok in ep.metadata.known.tokens.iter() {
2021
tokens.push(tok.clone());
2122
}
2223
}
23-
}
24+
tokens
25+
})
26+
.iter()
27+
.flatten()
28+
.map(|tok| tokens.push(tok));
2429

2530
let total_token_size: usize = tokens.iter().map(|t| t.len()).sum();
2631
let pool = std::sync::Arc::new(quilkin::collections::BufferPool::new(1, 1));

src/config.rs

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -644,40 +644,43 @@ impl Config {
644644
}
645645
}
646646
ResourceType::Cluster => {
647-
let mut push = |key: &Option<crate::net::endpoint::Locality>,
648-
value: &crate::net::cluster::EndpointSet|
649-
-> crate::Result<()> {
650-
let version = value.version().to_string();
651-
let key_s = key.as_ref().map(|k| k.to_string()).unwrap_or_default();
652-
653-
if client_state.version_matches(&key_s, &version) {
654-
return Ok(());
655-
}
656-
657-
let resource = crate::xds::Resource::Cluster(
658-
quilkin_xds::generated::quilkin::config::v1alpha1::Cluster {
659-
locality: key.clone().map(|l| l.into()),
660-
endpoints: value.endpoints.iter().map(|ep| ep.into()).collect(),
661-
},
662-
);
647+
let cluster_resource_transformer =
648+
|key: &Option<crate::net::endpoint::Locality>,
649+
value: &crate::net::cluster::EndpointSet|
650+
-> crate::Result<Option<XdsResource>> {
651+
let version = value.version().to_string();
652+
let key_s = key.as_ref().map(|k| k.to_string()).unwrap_or_default();
653+
654+
if client_state.version_matches(&key_s, &version) {
655+
return Ok(None);
656+
}
663657

664-
resources.push(XdsResource {
665-
name: key_s,
666-
version,
667-
resource: Some(resource.try_encode()?),
668-
..Default::default()
669-
});
658+
let resource = crate::xds::Resource::Cluster(
659+
quilkin_xds::generated::quilkin::config::v1alpha1::Cluster {
660+
locality: key.clone().map(|l| l.into()),
661+
endpoints: value.endpoints.iter().map(|ep| ep.into()).collect(),
662+
},
663+
);
670664

671-
Ok(())
672-
};
665+
Ok(Some(XdsResource {
666+
name: key_s,
667+
version,
668+
resource: Some(resource.try_encode()?),
669+
..Default::default()
670+
}))
671+
};
673672

674673
let Some(clusters) = self.dyn_cfg.clusters() else {
675674
break 'append;
676675
};
677676

678677
if client_state.subscribed.is_empty() {
679-
for cluster in clusters.read().iter() {
680-
push(cluster.key(), cluster.value())?;
678+
for cluster_resources in
679+
clusters.read().iter_with(cluster_resource_transformer)
680+
{
681+
if let Some(clr) = cluster_resources? {
682+
resources.push(clr);
683+
}
681684
}
682685
} else {
683686
for locality in client_state.subscribed.iter().filter_map(|name| {
@@ -687,18 +690,22 @@ impl Config {
687690
name.parse().ok().map(Some)
688691
}
689692
}) {
690-
if let Some(cluster) = clusters.read().get(&locality) {
691-
push(cluster.key(), cluster.value())?;
693+
if let Some(cluster_resource) =
694+
clusters.read().with_value(&locality, |entry| {
695+
cluster_resource_transformer(entry.key(), entry.value())
696+
})
697+
{
698+
if let Some(clr) = cluster_resource? {
699+
resources.push(clr);
700+
}
692701
}
693702
}
694703
};
695704

696705
// Currently, we have exactly _one_ special case for removed resources, which
697706
// is when ClusterMap::update_unlocated_endpoints is called to move the None
698707
// locality endpoints to another one, so we just detect that case manually
699-
if client_state.versions.contains_key("")
700-
&& clusters.read().get(&None).is_none()
701-
{
708+
if client_state.versions.contains_key("") && !clusters.read().exists(&None) {
702709
removed.insert("".into());
703710
}
704711
}

src/metrics.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -676,14 +676,13 @@ pub(crate) fn apply_clusters(clusters: &crate::config::Watch<crate::net::Cluster
676676
let clusters = clusters.read();
677677
crate::net::cluster::active_clusters().set(clusters.len() as i64);
678678

679-
for entry in clusters.iter() {
679+
clusters.iter_with(|locality, endpoint_set| {
680680
crate::net::cluster::active_endpoints(
681-
&entry
682-
.key()
681+
&locality
683682
.clone()
684683
.map(|key| key.to_string())
685684
.unwrap_or_default(),
686685
)
687-
.set(entry.value().len() as i64);
688-
}
686+
.set(endpoint_set.len() as i64);
687+
});
689688
}

src/net/cluster.rs

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,6 @@ pub struct ClusterMap<S = gxhash::GxBuildHasher> {
263263
version: AtomicU64,
264264
}
265265

266-
type DashMapRef<'inner> = dashmap::mapref::one::Ref<'inner, Option<Locality>, EndpointSet>;
267-
268266
impl ClusterMap {
269267
pub fn new() -> Self {
270268
Self::default()
@@ -367,15 +365,39 @@ where
367365
}
368366

369367
#[inline]
370-
pub fn get(&self, key: &Option<Locality>) -> Option<DashMapRef<'_>> {
371-
self.map.get(key)
368+
pub fn exists(&self, key: &Option<Locality>) -> bool {
369+
self.map.get(key).is_some()
370+
}
371+
372+
#[inline]
373+
pub fn with_value<F, T>(&self, key: &Option<Locality>, func: F) -> Option<T>
374+
where
375+
F: FnOnce(dashmap::mapref::one::Ref<'_, Option<Locality>, EndpointSet>) -> T,
376+
{
377+
self.map.get(key).map(func)
372378
}
373379

374380
#[inline]
375381
pub fn insert_default(&self, endpoints: BTreeSet<Endpoint>) {
376382
self.insert(None, None, endpoints);
377383
}
378384

385+
/// Iterates over the entries in the `ClusterMap` with the given func
386+
///
387+
/// This ensures that the dashmap entry references are never held across await boundaries as
388+
/// the func cannot be async.
389+
#[inline]
390+
pub fn iter_with<F, T>(&self, func: F) -> Vec<T>
391+
where
392+
F: for<'a> Fn(&'a Option<Locality>, &'a EndpointSet) -> T,
393+
{
394+
let mut results: Vec<T> = Vec::new();
395+
for entry in self.map.iter() {
396+
results.push(func(entry.key(), entry.value()));
397+
}
398+
results
399+
}
400+
379401
#[inline]
380402
pub fn remove_endpoint(&self, needle: &Endpoint) -> bool {
381403
let locality = 'l: {
@@ -430,11 +452,6 @@ where
430452
true
431453
}
432454

433-
#[inline]
434-
pub fn iter(&self) -> dashmap::iter::Iter<'_, Option<Locality>, EndpointSet, S> {
435-
self.map.iter()
436-
}
437-
438455
#[inline]
439456
pub fn replace(
440457
&self,
@@ -477,7 +494,7 @@ where
477494
}
478495

479496
pub fn nth_endpoint(&self, mut index: usize) -> Option<Endpoint> {
480-
for set in self.iter() {
497+
for set in self.map.iter() {
481498
let set = &set.value().endpoints;
482499
if index < set.len() {
483500
return set.iter().nth(index).cloned();
@@ -492,7 +509,7 @@ where
492509
pub fn filter_endpoints(&self, f: impl Fn(&Endpoint) -> bool) -> Vec<Endpoint> {
493510
let mut endpoints = Vec::new();
494511

495-
for set in self.iter() {
512+
for set in self.map.iter() {
496513
for endpoint in set.endpoints.iter().filter(|e| (f)(e)) {
497514
endpoints.push(endpoint.clone());
498515
}
@@ -639,8 +656,9 @@ where
639656
S: Default + std::hash::BuildHasher + Clone,
640657
{
641658
fn eq(&self, rhs: &Self) -> bool {
642-
for a in self.iter() {
659+
for a in self.map.iter() {
643660
match rhs
661+
.map
644662
.get(a.key())
645663
.filter(|b| a.value().endpoints == b.endpoints)
646664
{
@@ -808,16 +826,18 @@ mod tests {
808826
cluster1.insert(None, Some(nl1.clone()), [endpoint.clone()].into());
809827
cluster1.insert(None, Some(de1.clone()), [endpoint.clone()].into());
810828

811-
assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1);
829+
assert_eq!(cluster1.map.get(&Some(nl1.clone())).unwrap().len(), 1);
812830
assert!(
813831
cluster1
832+
.map
814833
.get(&Some(nl1.clone()))
815834
.unwrap()
816835
.contains(&endpoint)
817836
);
818-
assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1);
837+
assert_eq!(cluster1.map.get(&Some(de1.clone())).unwrap().len(), 1);
819838
assert!(
820839
cluster1
840+
.map
821841
.get(&Some(de1.clone()))
822842
.unwrap()
823843
.contains(&endpoint)
@@ -827,19 +847,20 @@ mod tests {
827847

828848
cluster1.insert(None, Some(de1.clone()), [endpoint.clone()].into());
829849

830-
assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1);
831-
assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1);
850+
assert_eq!(cluster1.map.get(&Some(nl1.clone())).unwrap().len(), 1);
851+
assert_eq!(cluster1.map.get(&Some(de1.clone())).unwrap().len(), 1);
832852
assert!(
833853
cluster1
854+
.map
834855
.get(&Some(de1.clone()))
835856
.unwrap()
836857
.contains(&endpoint)
837858
);
838859

839860
cluster1.insert(None, Some(de1.clone()), <_>::default());
840861

841-
assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1);
842-
assert!(cluster1.get(&Some(de1.clone())).unwrap().is_empty());
862+
assert_eq!(cluster1.map.get(&Some(nl1.clone())).unwrap().len(), 1);
863+
assert!(cluster1.map.get(&Some(de1.clone())).unwrap().is_empty());
843864
}
844865

845866
#[test]
@@ -862,12 +883,15 @@ mod tests {
862883
[Endpoint::new((Ipv4Addr::new(20, 20, 20, 20), 1234).into())].into();
863884

864885
cluster.insert(Some(nl02.into()), Some(nl1.clone()), not_expected.clone());
865-
assert_eq!(cluster.get(&Some(nl1.clone())).unwrap().endpoints, expected);
886+
assert_eq!(
887+
cluster.map.get(&Some(nl1.clone())).unwrap().endpoints,
888+
expected
889+
);
866890

867891
cluster.remove_locality(Some(nl01.into()), &Some(nl1.clone()));
868892

869893
cluster.insert(Some(nl02.into()), Some(nl1.clone()), not_expected.clone());
870-
assert_eq!(cluster.get(&Some(nl1)).unwrap().endpoints, not_expected);
894+
assert_eq!(cluster.map.get(&Some(nl1)).unwrap().endpoints, not_expected);
871895
}
872896

873897
#[test]
@@ -886,7 +910,7 @@ mod tests {
886910
assert!(cluster.remove_endpoint(ep));
887911
}
888912

889-
assert!(cluster.get(&None).is_none());
913+
assert!(cluster.map.get(&None).is_none());
890914
}
891915

892916
{
@@ -895,7 +919,7 @@ mod tests {
895919
assert!(cluster.remove_endpoint_if(|_ep| true));
896920
}
897921

898-
assert!(cluster.get(&None).is_none());
922+
assert!(cluster.map.get(&None).is_none());
899923
}
900924
}
901925
}

0 commit comments

Comments
 (0)