Skip to content

Commit 21f62f7

Browse files
authored
Refactor phoenix update task, add tests, fix edge cases (#1260)
1 parent 45d3a6d commit 21f62f7

File tree

1 file changed

+152
-57
lines changed

1 file changed

+152
-57
lines changed

src/net/phoenix.rs

Lines changed: 152 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -301,55 +301,63 @@ impl<M: Measurement + 'static> Phoenix<M> {
301301
Builder::new(measurement)
302302
}
303303

304-
/// Starts the background update task to continously sample from nodes
305-
/// and update their coordinates.
306-
pub async fn background_update_task(&self) {
307-
let mut current_interval = self.interval_range.start;
308-
let mut first = Some(self.all_nodes());
304+
async fn update(&self, mut current_interval: std::time::Duration) -> std::time::Duration {
305+
let mut total_difference = 0;
306+
let mut count = 0;
309307

310-
loop {
311-
let mut total_difference = 0;
312-
let mut count = 0;
313-
314-
let nodes_to_probe = first
315-
.take()
316-
.unwrap_or_else(|| self.random_subset_of_nodes());
317-
318-
for address in nodes_to_probe {
319-
let Some(mut node) = self.nodes.get_mut(&address) else {
320-
tracing::debug!(%address, "node removed between selection and measurement");
321-
continue;
322-
};
323-
324-
match self.measurement.measure_distance(address).await {
325-
Ok(distance) => {
326-
node.adjust_coordinates(distance);
327-
total_difference += distance.total_nanos();
328-
count += 1;
329-
}
330-
Err(error) => {
308+
let nodes = self.select_nodes_to_probe();
309+
310+
for address in nodes {
311+
let Some(mut node) = self.nodes.get_mut(&address) else {
312+
tracing::debug!(%address, "node removed between selection and measurement");
313+
continue;
314+
};
315+
316+
match self.measurement.measure_distance(address).await {
317+
Ok(distance) => {
318+
node.adjust_coordinates(distance);
319+
total_difference += distance.total_nanos();
320+
count += 1;
321+
}
322+
Err(error) => {
323+
node.increase_error_estimate();
324+
let consecutive_errors = node.consecutive_errors();
325+
if consecutive_errors > 3 {
326+
tracing::warn!(%address, %error, %consecutive_errors, "error measuring distance");
327+
} else {
331328
tracing::debug!(%address, %error, "error measuring distance");
332-
node.increase_error_estimate();
333329
}
334330
}
335331
}
332+
}
336333

337-
if count > 0 {
338-
let avg_difference_ns = total_difference / count;
334+
if count > 0 {
335+
let avg_difference_ns = total_difference / count;
339336

340-
// Adjust the interval based on the avg_difference
341-
if Duration::from_nanos(avg_difference_ns as u64) < self.stability_threshold {
342-
current_interval += self.adjustment_duration;
343-
} else {
344-
current_interval -= self.adjustment_duration;
345-
}
346-
347-
// Ensure current_interval remains within bounds
348-
current_interval =
349-
current_interval.clamp(self.interval_range.start, self.interval_range.end);
337+
// Adjust the interval based on the avg_difference
338+
if Duration::from_nanos(avg_difference_ns as u64) < self.stability_threshold {
339+
current_interval += self.adjustment_duration;
340+
} else {
341+
current_interval -= self.adjustment_duration;
350342
}
351343

352-
let _ = self.update_watcher.0.send(());
344+
// Ensure current_interval remains within bounds
345+
current_interval =
346+
current_interval.clamp(self.interval_range.start, self.interval_range.end);
347+
}
348+
349+
let _ = self.update_watcher.0.send(());
350+
current_interval
351+
}
352+
353+
/// Starts the background update task to continously sample from nodes
354+
/// and update their coordinates.
355+
pub async fn background_update_task(&self) {
356+
let mut current_interval = self.interval_range.start;
357+
358+
loop {
359+
current_interval = self.update(current_interval).await;
360+
353361
tokio::time::sleep(current_interval).await;
354362
}
355363
}
@@ -358,33 +366,40 @@ impl<M: Measurement + 'static> Phoenix<M> {
358366
self.update_watcher.1.clone()
359367
}
360368

369+
#[allow(dead_code)]
361370
fn all_nodes(&self) -> Vec<SocketAddr> {
362371
self.nodes
363372
.iter()
364373
.map(|entry| *entry.key())
365374
.collect::<Vec<_>>()
366375
}
367376

368-
fn random_subset_of_nodes(&self) -> Vec<SocketAddr> {
377+
/// Returns a set of node addresses to probe.
378+
///
379+
/// - Always returns at least 1 node unless the list of nodes is empty
380+
/// - Always returns all of the nodes that have not been mapped yet
381+
/// - Returns a randomly selected subset of nodes that have been mapped
382+
fn select_nodes_to_probe(&self) -> Vec<SocketAddr> {
369383
use rand::seq::SliceRandom;
370-
let unmapped_nodes = self
384+
385+
let (unmapped, mut mapped): (Vec<_>, Vec<_>) = self
371386
.nodes
372387
.iter()
373-
.filter(|entry| entry.coordinates.is_none());
374-
375-
if unmapped_nodes.clone().count() > 0 {
376-
unmapped_nodes.map(|entry| *entry.key()).collect()
377-
} else {
378-
let mut nodes = self
379-
.nodes
380-
.iter()
381-
.map(|entry| *entry.key())
382-
.collect::<Vec<_>>();
383-
nodes.shuffle(&mut rand::rng());
384-
let subset_size = (nodes.len() as f64 * self.subset_percentage).abs() as usize;
385-
386-
nodes[..subset_size].to_vec()
387-
}
388+
.partition(|entry| entry.coordinates.is_none());
389+
390+
mapped.shuffle(&mut rand::rng());
391+
392+
// Select a subset of the already mapped nodes, but always at least one node
393+
let subset_size = (mapped.len() as f64 * self.subset_percentage)
394+
.abs()
395+
.max(1.0) as usize;
396+
397+
mapped
398+
.iter()
399+
.map(|entry| *entry.key())
400+
.take(subset_size)
401+
.chain(unmapped.iter().map(|entry| *entry.key())) // Always include all unmapped nodes
402+
.collect()
388403
}
389404

390405
#[cfg(test)]
@@ -582,6 +597,7 @@ struct Node {
582597
coordinates: Option<Coordinates>,
583598
icao_code: IcaoCode,
584599
error_estimate: f64,
600+
consecutive_errors: u64,
585601
}
586602

587603
impl Node {
@@ -591,15 +607,22 @@ impl Node {
591607
coordinates: None,
592608
icao_code,
593609
error_estimate: 1.0,
610+
consecutive_errors: 0,
594611
}
595612
}
596613

614+
fn consecutive_errors(&self) -> u64 {
615+
self.consecutive_errors
616+
}
617+
597618
fn increase_error_estimate(&mut self) {
598619
self.error_estimate += 0.1;
620+
self.consecutive_errors += 1;
599621
crate::metrics::phoenix_distance_error_estimate(self.icao_code).set(self.error_estimate);
600622
}
601623

602624
fn adjust_coordinates(&mut self, distance: DistanceMeasure) {
625+
self.consecutive_errors = 0;
603626
let incoming = distance.incoming.nanos() as f64;
604627
let outgoing = distance.outgoing.nanos() as f64;
605628

@@ -716,6 +739,78 @@ mod tests {
716739
});
717740
}
718741

742+
#[test]
743+
fn zero_nodes_return_empty_subset() {
744+
let phoenix = Phoenix::new(MockMeasurement {
745+
latencies: <_>::default(),
746+
});
747+
748+
assert_eq!(phoenix.select_nodes_to_probe(), vec![]);
749+
}
750+
751+
#[tokio::test]
752+
async fn one_node_returns_single_node_subset() {
753+
let phoenix = Phoenix::new(MockMeasurement {
754+
latencies: <_>::default(),
755+
});
756+
757+
let socket_addr = "127.0.0.1:8080".parse().unwrap();
758+
phoenix.add_node(socket_addr, abcd());
759+
760+
// First time it will be returned as part of "unmapped_nodes"
761+
assert_eq!(phoenix.select_nodes_to_probe(), vec![socket_addr]);
762+
phoenix.measure_all_nodes().await;
763+
// After it has been measured it should still be returned so we don't get stuck without
764+
// ever making additional measurements
765+
assert_eq!(phoenix.select_nodes_to_probe(), vec![socket_addr]);
766+
}
767+
768+
#[tokio::test]
769+
async fn select_nodes_to_probe() {
770+
let latencies = HashMap::from([
771+
("127.0.0.1:8080".parse().unwrap(), (100, 100).into()),
772+
("127.0.0.1:8081".parse().unwrap(), (200, 200).into()),
773+
("127.0.0.1:8082".parse().unwrap(), (200, 200).into()),
774+
("127.0.0.1:8083".parse().unwrap(), (200, 200).into()),
775+
("127.0.0.1:8084".parse().unwrap(), (200, 200).into()),
776+
]);
777+
let failed_address = "127.0.0.1:8080".parse::<SocketAddr>().unwrap();
778+
let failed_addresses = Arc::new(Mutex::new(HashSet::from([failed_address])));
779+
let phoenix = Phoenix::builder(FailedAddressesMock {
780+
latencies,
781+
failed_addresses,
782+
})
783+
.subset_percentage(0.25)
784+
.build();
785+
786+
phoenix.add_node("127.0.0.1:8080".parse().unwrap(), abcd());
787+
phoenix.add_node("127.0.0.1:8081".parse().unwrap(), efgh());
788+
phoenix.add_node("127.0.0.1:8082".parse().unwrap(), efgh());
789+
phoenix.add_node("127.0.0.1:8083".parse().unwrap(), efgh());
790+
phoenix.add_node("127.0.0.1:8084".parse().unwrap(), efgh());
791+
792+
let mut nodes_to_probe = phoenix.select_nodes_to_probe();
793+
nodes_to_probe.sort();
794+
let expected_nodes_to_probe = vec![
795+
"127.0.0.1:8080".parse().unwrap(),
796+
"127.0.0.1:8081".parse().unwrap(),
797+
"127.0.0.1:8082".parse().unwrap(),
798+
"127.0.0.1:8083".parse().unwrap(),
799+
"127.0.0.1:8084".parse().unwrap(),
800+
];
801+
assert_eq!(nodes_to_probe, expected_nodes_to_probe);
802+
803+
phoenix.measure_all_nodes().await;
804+
805+
// Ensure that we always get the node that has not been mapped yet, as well as 1 out of the
806+
// 4 mapped nodes due to the 25% subset percentage
807+
for _ in 0..10 {
808+
let nodes_to_probe = phoenix.select_nodes_to_probe();
809+
assert_eq!(nodes_to_probe.len(), 2);
810+
assert!(nodes_to_probe.contains(&failed_address));
811+
}
812+
}
813+
719814
#[tokio::test]
720815
async fn coordinates_adjustment() {
721816
let mut mock_latencies = HashMap::new();

0 commit comments

Comments
 (0)