diff --git a/benches/cluster_map.rs b/benches/cluster_map.rs index 973e50d097..3765defd33 100644 --- a/benches/cluster_map.rs +++ b/benches/cluster_map.rs @@ -17,20 +17,18 @@ mod serde { fn serialize_to_protobuf(cm: &ClusterMap) -> Vec { let mut resources = Vec::new(); - for cluster in cm.iter() { - resources.push( - Resource::Cluster(Cluster { - locality: cluster.key().clone().map(From::from), - endpoints: cluster - .endpoints - .iter() - .map(TryFrom::try_from) - .collect::>() - .unwrap(), - }) - .try_encode() - .unwrap(), - ); + for resource in cm.iter_with(|locality, endpoint_set| { + Resource::Cluster(Cluster { + locality: locality.clone().map(From::from), + endpoints: endpoint_set + .endpoints + .iter() + .map(TryFrom::try_from) + .collect::>() + .unwrap(), + }) + }) { + resources.push(resource.try_encode().unwrap()); } resources @@ -111,11 +109,11 @@ mod ops { use shared::{GenCluster, gen_cluster_map}; fn compute_hash(gc: &GenCluster) -> usize { - let mut total_endpoints = 0; - - for kv in gc.cm.iter() { - total_endpoints += kv.endpoints.len(); - } + let total_endpoints = gc + .cm + .iter_with(|_locality, endpoint_set| endpoint_set.len()) + .iter() + .sum(); assert_eq!(total_endpoints, gc.total_endpoints); total_endpoints diff --git a/benches/shared.rs b/benches/shared.rs index 80b038cee4..44a5974f8e 100644 --- a/benches/shared.rs +++ b/benches/shared.rs @@ -690,7 +690,7 @@ pub fn gen_cluster_map(token_kind: TokenKind) -> GenCluster { // Now actually insert the endpoints, now that the order of keys is established, // annoying, but note we split out iteration versus insertion, otherwise we deadlock - let keys: Vec<_> = cm.iter().map(|kv| kv.key().clone()).collect(); + let keys = cm.iter_with(|locality, _endpoint_set| locality.clone()); let mut sets = std::collections::BTreeMap::new(); let mut token_generator = match token_kind { diff --git a/benches/token_router.rs b/benches/token_router.rs index 30ebf2b283..976f55879a 100644 --- a/benches/token_router.rs +++ b/benches/token_router.rs @@ -9,18 +9,22 @@ fn token_router(b: Bencher<'_, '_>, token_kind: &str) { let filter = TokenRouter::default(); let gc = shared::gen_cluster_map::<42>(token_kind.parse().unwrap()); - let mut tokens = Vec::new(); - let cm = std::sync::Arc::new(gc.cm); // Calculate the amount of bytes for all the tokens - for eps in cm.iter() { - for ep in &eps.value().endpoints { - for tok in ep.metadata.known.tokens.iter() { - tokens.push(tok.clone()); + let tokens: Vec> = cm + .iter_with(|_locality, endpoint_set| { + let mut tokens: Vec> = Vec::new(); + for ep in &endpoint_set.endpoints { + for tok in ep.metadata.known.tokens.iter() { + tokens.push(tok.clone()); + } } - } - } + tokens + }) + .into_iter() + .flatten() + .collect(); let total_token_size: usize = tokens.iter().map(|t| t.len()).sum(); let pool = std::sync::Arc::new(quilkin::collections::BufferPool::new(1, 1)); diff --git a/crates/test/tests/mesh.rs b/crates/test/tests/mesh.rs index 860055be03..1cd8d3c65e 100644 --- a/crates/test/tests/mesh.rs +++ b/crates/test/tests/mesh.rs @@ -189,12 +189,17 @@ trace_test!(datacenter_discovery, { ) -> bool { let dcs = config.dyn_cfg.datacenters().unwrap().read(); - for i in dcs.iter() { - dbg!(which, i.key(), i.value()); - } + drop(dcs.iter_with(|key, value| { + dbg!(which, key, value); + })); + + let entries = dcs + .iter_with(|ip_addr, dc| (ip_addr.clone(), dc.clone())) + .into_iter() + .collect::>(); - let ipv4_dc = dcs.get(&std::net::Ipv4Addr::LOCALHOST.into()); - let ipv6_dc = dcs.get(&std::net::Ipv6Addr::LOCALHOST.into()); + let ipv4_dc = entries.get(&std::net::Ipv4Addr::LOCALHOST.into()); + let ipv6_dc = entries.get(&std::net::Ipv6Addr::LOCALHOST.into()); if counter > 0 { match (ipv4_dc, ipv6_dc) { diff --git a/crates/xds/src/client.rs b/crates/xds/src/client.rs index d1774b1af2..1a1b4d29e3 100644 --- a/crates/xds/src/client.rs +++ b/crates/xds/src/client.rs @@ -446,7 +446,7 @@ impl DeltaClientStream { ) -> Result<()> { crate::metrics::actions_total(KIND_CLIENT, "refresh").inc(); for (rt, names) in subs { - let initial_resource_versions = local.get(rt).clone(); + let initial_resource_versions = local.get_versions(rt); self.req_tx .send(DeltaDiscoveryRequest { node: Some(Node { diff --git a/crates/xds/src/config.rs b/crates/xds/src/config.rs index e1b467a437..f0a0aee51f 100644 --- a/crates/xds/src/config.rs +++ b/crates/xds/src/config.rs @@ -40,7 +40,7 @@ impl LocalVersions { } #[inline] - pub fn get(&self, ty: &str) -> parking_lot::MutexGuard<'_, VersionMap> { + fn get_map_guard(&self, ty: &str) -> parking_lot::MutexGuard<'_, VersionMap> { let g = self .versions .iter() @@ -53,6 +53,33 @@ impl LocalVersions { panic!("unable to retrieve `{ty}` versions, available versions are {versions:?}"); } } + + #[inline] + pub fn update_versions( + &self, + type_url: &str, + removed_resources: &Vec, + updated_resources: Vec<(String, String)>, + ) { + let mut guard = self.get_map_guard(type_url); + + // Remove any resources the upstream server has removed/doesn't have, + // we do this before applying any new/updated resources in case a + // resource is in both lists, though really that would be a bug in + // the upstream server + for removed in removed_resources { + guard.remove(removed); + } + + for (k, v) in updated_resources { + guard.insert(k, v); + } + } + + #[inline] + pub fn get_versions(&self, type_url: &str) -> VersionMap { + self.get_map_guard(type_url).clone() + } } pub struct ClientState { @@ -268,19 +295,7 @@ pub fn handle_delta_discovery_responses( let res = config.apply_delta(&type_url, response.resources, &response.removed_resources, remote_addr); if res.is_ok() { - let mut lock = local.get(&type_url); - - // Remove any resources the upstream server has removed/doesn't have, - // we do this before applying any new/updated resources in case a - // resource is in both lists, though really that would be a bug in - // the upstream server - for removed in response.removed_resources { - lock.remove(&removed); - } - - for (k, v) in version_map { - lock.insert(k, v); - } + local.update_versions(&type_url, &response.removed_resources, version_map); } res diff --git a/src/config.rs b/src/config.rs index a3c8faf022..7a259ab431 100644 --- a/src/config.rs +++ b/src/config.rs @@ -594,32 +594,40 @@ impl Config { } if let Some(datacenters) = self.dyn_cfg.datacenters() { - for entry in datacenters.read().iter() { - let host = entry.key().to_string(); - let qcmp_port = entry.qcmp_port; - let version = - resource_version(entry.icao_code.to_string().as_str(), qcmp_port); - - if client_state.version_matches(&host, &version) { - continue; - } + let dc_resource_transformer = + |ip_addr: &std::net::IpAddr, + dc: &Datacenter| + -> eyre::Result> { + let host = ip_addr.to_string(); + let qcmp_port = dc.qcmp_port; + let version = + resource_version(dc.icao_code.to_string().as_str(), qcmp_port); + + if client_state.version_matches(&host, &version) { + return Ok(None); + } - let resource = crate::xds::Resource::Datacenter( - crate::net::cluster::proto::Datacenter { - qcmp_port: qcmp_port as _, - icao_code: entry.icao_code.to_string(), - host: host.clone(), - }, - ); + let resource = crate::xds::Resource::Datacenter( + crate::net::cluster::proto::Datacenter { + qcmp_port: qcmp_port as _, + icao_code: dc.icao_code.to_string(), + host: host.clone(), + }, + ); + + Ok(Some(XdsResource { + name: host, + version, + resource: Some(resource.try_encode()?), + aliases: Vec::new(), + ttl: None, + cache_control: None, + })) + }; - resources.push(XdsResource { - name: host, - version, - resource: Some(resource.try_encode()?), - aliases: Vec::new(), - ttl: None, - cache_control: None, - }); + for resource in datacenters.read().iter_with(dc_resource_transformer) { + let Some(resource) = resource? else { continue }; + resources.push(resource); } { @@ -628,7 +636,7 @@ impl Config { let Ok(addr) = key.parse() else { continue; }; - if dc.get(&addr).is_none() { + if !dc.exists(&addr) { removed.insert(key.clone()); } } @@ -636,40 +644,43 @@ impl Config { } } ResourceType::Cluster => { - let mut push = |key: &Option, - value: &crate::net::cluster::EndpointSet| - -> crate::Result<()> { - let version = value.version().to_string(); - let key_s = key.as_ref().map(|k| k.to_string()).unwrap_or_default(); - - if client_state.version_matches(&key_s, &version) { - return Ok(()); - } - - let resource = crate::xds::Resource::Cluster( - quilkin_xds::generated::quilkin::config::v1alpha1::Cluster { - locality: key.clone().map(|l| l.into()), - endpoints: value.endpoints.iter().map(|ep| ep.into()).collect(), - }, - ); + let cluster_resource_transformer = + |key: &Option, + value: &crate::net::cluster::EndpointSet| + -> crate::Result> { + let version = value.version().to_string(); + let key_s = key.as_ref().map(|k| k.to_string()).unwrap_or_default(); + + if client_state.version_matches(&key_s, &version) { + return Ok(None); + } - resources.push(XdsResource { - name: key_s, - version, - resource: Some(resource.try_encode()?), - ..Default::default() - }); + let resource = crate::xds::Resource::Cluster( + quilkin_xds::generated::quilkin::config::v1alpha1::Cluster { + locality: key.clone().map(|l| l.into()), + endpoints: value.endpoints.iter().map(|ep| ep.into()).collect(), + }, + ); - Ok(()) - }; + Ok(Some(XdsResource { + name: key_s, + version, + resource: Some(resource.try_encode()?), + ..Default::default() + })) + }; let Some(clusters) = self.dyn_cfg.clusters() else { break 'append; }; if client_state.subscribed.is_empty() { - for cluster in clusters.read().iter() { - push(cluster.key(), cluster.value())?; + for cluster_resources in + clusters.read().iter_with(cluster_resource_transformer) + { + if let Some(clr) = cluster_resources? { + resources.push(clr); + } } } else { for locality in client_state.subscribed.iter().filter_map(|name| { @@ -679,8 +690,14 @@ impl Config { name.parse().ok().map(Some) } }) { - if let Some(cluster) = clusters.read().get(&locality) { - push(cluster.key(), cluster.value())?; + if let Some(cluster_resource) = + clusters.read().with_value(&locality, |entry| { + cluster_resource_transformer(entry.key(), entry.value()) + }) + { + if let Some(clr) = cluster_resource? { + resources.push(clr); + } } } }; @@ -688,9 +705,7 @@ impl Config { // Currently, we have exactly _one_ special case for removed resources, which // is when ClusterMap::update_unlocated_endpoints is called to move the None // locality endpoints to another one, so we just detect that case manually - if client_state.versions.contains_key("") - && clusters.read().get(&None).is_none() - { + if client_state.versions.contains_key("") && !clusters.read().exists(&None) { removed.insert("".into()); } } diff --git a/src/config/datacenter.rs b/src/config/datacenter.rs index a5d3501cab..d31fea57e1 100644 --- a/src/config/datacenter.rs +++ b/src/config/datacenter.rs @@ -37,13 +37,24 @@ impl DatacenterMap { } #[inline] - pub fn get(&self, key: &IpAddr) -> Option> { - self.map.get(key) + pub fn exists(&self, key: &IpAddr) -> bool { + self.map.get(key).is_some() } + /// Iterates over the entries in the `DatacenterMap` with the given func + /// + /// This ensures that the dashmap entry references are never held across await boundaries as + /// the func cannot be async. #[inline] - pub fn iter(&self) -> dashmap::iter::Iter<'_, IpAddr, Datacenter> { - self.map.iter() + pub fn iter_with(&self, func: F) -> Vec + where + F: for<'a> Fn(&'a IpAddr, &'a Datacenter) -> T, + { + let mut results: Vec = Vec::new(); + for entry in self.map.iter() { + results.push(func(entry.key(), entry.value())); + } + results } #[inline] @@ -113,8 +124,8 @@ impl PartialEq for DatacenterMap { return false; } - for a in self.iter() { - match rhs.get(a.key()).filter(|b| *a.value() == **b) { + for a in self.map.iter() { + match rhs.map.get(a.key()).filter(|b| *a.value() == **b) { Some(_) => {} None => return false, } diff --git a/src/metrics.rs b/src/metrics.rs index fcf76b4892..9749a8550a 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -676,14 +676,13 @@ pub(crate) fn apply_clusters(clusters: &crate::config::Watch { version: AtomicU64, } -type DashMapRef<'inner> = dashmap::mapref::one::Ref<'inner, Option, EndpointSet>; - impl ClusterMap { pub fn new() -> Self { Self::default() @@ -367,8 +365,16 @@ where } #[inline] - pub fn get(&self, key: &Option) -> Option> { - self.map.get(key) + pub fn exists(&self, key: &Option) -> bool { + self.map.get(key).is_some() + } + + #[inline] + pub fn with_value(&self, key: &Option, func: F) -> Option + where + F: FnOnce(dashmap::mapref::one::Ref<'_, Option, EndpointSet>) -> T, + { + self.map.get(key).map(func) } #[inline] @@ -376,6 +382,22 @@ where self.insert(None, None, endpoints); } + /// Iterates over the entries in the `ClusterMap` with the given func + /// + /// This ensures that the dashmap entry references are never held across await boundaries as + /// the func cannot be async. + #[inline] + pub fn iter_with(&self, func: F) -> Vec + where + F: for<'a> Fn(&'a Option, &'a EndpointSet) -> T, + { + let mut results: Vec = Vec::new(); + for entry in self.map.iter() { + results.push(func(entry.key(), entry.value())); + } + results + } + #[inline] pub fn remove_endpoint(&self, needle: &Endpoint) -> bool { let locality = 'l: { @@ -430,11 +452,6 @@ where true } - #[inline] - pub fn iter(&self) -> dashmap::iter::Iter<'_, Option, EndpointSet, S> { - self.map.iter() - } - #[inline] pub fn replace( &self, @@ -477,7 +494,7 @@ where } pub fn nth_endpoint(&self, mut index: usize) -> Option { - for set in self.iter() { + for set in self.map.iter() { let set = &set.value().endpoints; if index < set.len() { return set.iter().nth(index).cloned(); @@ -492,7 +509,7 @@ where pub fn filter_endpoints(&self, f: impl Fn(&Endpoint) -> bool) -> Vec { let mut endpoints = Vec::new(); - for set in self.iter() { + for set in self.map.iter() { for endpoint in set.endpoints.iter().filter(|e| (f)(e)) { endpoints.push(endpoint.clone()); } @@ -639,8 +656,9 @@ where S: Default + std::hash::BuildHasher + Clone, { fn eq(&self, rhs: &Self) -> bool { - for a in self.iter() { + for a in self.map.iter() { match rhs + .map .get(a.key()) .filter(|b| a.value().endpoints == b.endpoints) { @@ -808,16 +826,18 @@ mod tests { cluster1.insert(None, Some(nl1.clone()), [endpoint.clone()].into()); cluster1.insert(None, Some(de1.clone()), [endpoint.clone()].into()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.map.get(&Some(nl1.clone())).unwrap().len(), 1); assert!( cluster1 + .map .get(&Some(nl1.clone())) .unwrap() .contains(&endpoint) ); - assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.map.get(&Some(de1.clone())).unwrap().len(), 1); assert!( cluster1 + .map .get(&Some(de1.clone())) .unwrap() .contains(&endpoint) @@ -827,10 +847,11 @@ mod tests { cluster1.insert(None, Some(de1.clone()), [endpoint.clone()].into()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); - assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.map.get(&Some(nl1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.map.get(&Some(de1.clone())).unwrap().len(), 1); assert!( cluster1 + .map .get(&Some(de1.clone())) .unwrap() .contains(&endpoint) @@ -838,8 +859,8 @@ mod tests { cluster1.insert(None, Some(de1.clone()), <_>::default()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); - assert!(cluster1.get(&Some(de1.clone())).unwrap().is_empty()); + assert_eq!(cluster1.map.get(&Some(nl1.clone())).unwrap().len(), 1); + assert!(cluster1.map.get(&Some(de1.clone())).unwrap().is_empty()); } #[test] @@ -862,12 +883,15 @@ mod tests { [Endpoint::new((Ipv4Addr::new(20, 20, 20, 20), 1234).into())].into(); cluster.insert(Some(nl02.into()), Some(nl1.clone()), not_expected.clone()); - assert_eq!(cluster.get(&Some(nl1.clone())).unwrap().endpoints, expected); + assert_eq!( + cluster.map.get(&Some(nl1.clone())).unwrap().endpoints, + expected + ); cluster.remove_locality(Some(nl01.into()), &Some(nl1.clone())); cluster.insert(Some(nl02.into()), Some(nl1.clone()), not_expected.clone()); - assert_eq!(cluster.get(&Some(nl1)).unwrap().endpoints, not_expected); + assert_eq!(cluster.map.get(&Some(nl1)).unwrap().endpoints, not_expected); } #[test] @@ -886,7 +910,7 @@ mod tests { assert!(cluster.remove_endpoint(ep)); } - assert!(cluster.get(&None).is_none()); + assert!(cluster.map.get(&None).is_none()); } { @@ -895,7 +919,7 @@ mod tests { assert!(cluster.remove_endpoint_if(|_ep| true)); } - assert!(cluster.get(&None).is_none()); + assert!(cluster.map.get(&None).is_none()); } } } diff --git a/src/net/phoenix.rs b/src/net/phoenix.rs index f7f49f5b74..dc16cf0bd7 100644 --- a/src/net/phoenix.rs +++ b/src/net/phoenix.rs @@ -388,10 +388,10 @@ impl Phoenix { self.nodes.remove(&removed); } - for entry in dcs.iter() { - let addr = (*entry.key(), entry.value().qcmp_port).into(); - self.add_node_if_not_exists(addr, entry.value().icao_code); - } + drop(dcs.iter_with(|ip_addr, dc| { + let socket_addr = (*ip_addr, dc.qcmp_port).into(); + self.add_node_if_not_exists(socket_addr, dc.icao_code); + })); } }