From d187afd9eb6c6ce202f6ec599e4fed53a703a9be Mon Sep 17 00:00:00 2001 From: Marco Molari Date: Fri, 31 Oct 2025 08:26:18 +0100 Subject: [PATCH 1/4] test: keep polytomies for inference test --- packages/treetime/src/commands/timetree/inference/runner.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/treetime/src/commands/timetree/inference/runner.rs b/packages/treetime/src/commands/timetree/inference/runner.rs index 72c91ff6..336571f2 100644 --- a/packages/treetime/src/commands/timetree/inference/runner.rs +++ b/packages/treetime/src/commands/timetree/inference/runner.rs @@ -208,6 +208,7 @@ mod tests { // --dates data/flu/h3n2/20/metadata.tsv \ // --branch-length-mode input \ // --sequence-len 1400 \ + // --keep-polytomies \ // --outdir tmp/python_treetime_baseline // // Output: tmp/python_treetime_baseline/dates.tsv @@ -216,6 +217,7 @@ mod tests { // for branch length distributions (same as this test), enabling direct comparison. let expected = btreemap! { o!("NODE_0000017") => 1996.974064, + o!("NODE_0000018") => 1997.116240, o!("NODE_0000012") => 1998.499705, o!("A/Canterbury/58/2000|CY009150|09/05/2000|New_Zealand||H3N2/8-1416") => 2000.681725, o!("NODE_0000011") => 1998.763998, From 90f10c2d3b91c5b2014f14dcd96a677bcf1932a7 Mon Sep 17 00:00:00 2001 From: Marco Molari Date: Fri, 31 Oct 2025 08:27:20 +0100 Subject: [PATCH 2/4] fix: ensure likely time is set after refining distribution from parent --- .../treetime/src/commands/timetree/inference/forward_pass.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/treetime/src/commands/timetree/inference/forward_pass.rs b/packages/treetime/src/commands/timetree/inference/forward_pass.rs index 01473f98..5397f543 100644 --- a/packages/treetime/src/commands/timetree/inference/forward_pass.rs +++ b/packages/treetime/src/commands/timetree/inference/forward_pass.rs @@ -18,8 +18,8 @@ pub fn propagate_distributions_forward(graph: &GraphAncestral) -> Result<(), Rep fn propagate_distributions_forward_single_node( node: &mut GraphNodeForward, ) -> Result<(), Report> { - set_likely_time(node); refine_distribution_from_parent(node)?; + set_likely_time(node); Ok(()) } From c8a7ff8e992861858bd13f446cac816d83c58be8 Mon Sep 17 00:00:00 2001 From: Marco Molari Date: Fri, 31 Oct 2025 18:49:05 +0100 Subject: [PATCH 3/4] fix: range-function convolution: shifted domain --- .../distribution/distribution_convolution.rs | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/packages/treetime/src/distribution/distribution_convolution.rs b/packages/treetime/src/distribution/distribution_convolution.rs index 1e8c7de6..c1f2b8f3 100644 --- a/packages/treetime/src/distribution/distribution_convolution.rs +++ b/packages/treetime/src/distribution/distribution_convolution.rs @@ -26,7 +26,7 @@ pub fn distribution_convolution(a: &Distribution, b: &Distribution) -> Result { - Ok(convolution_point_function(a, b)?) // + Ok(Distribution::Function(convolution_point_function(a, b)?)) }, (Distribution::Range(a), Distribution::Function(b)) | (Distribution::Function(b), Distribution::Range(a)) => { Ok(convolution_range_function(a, b)) // @@ -73,20 +73,32 @@ fn convolution_point_range(p: &DistributionPoint, r: &DistributionRange, f: &DistributionFunction, -) -> Result { +) -> Result, Report> { let t = f.t().map(|t| t + p.t()); let y = f.y().map(|y| y * p.amplitude()); - Distribution::function(t, y) + DistributionFunction::new(t, y) } fn convolution_range_function(r: &DistributionRange, f: &DistributionFunction) -> Distribution { - let t_out = f.t().clone(); - let mut y_out = Array1::zeros(f.y().len()); + // split in a convolution with + // - a point distribution (taking care of the shift + amplitude) + // - an interval centered on zero and of a fixed width (taking care of the smoothing) + + let shift = (r.start() + r.end()) / 2.0; + let amplitude = r.amplitude(); + let width = r.end() - r.start(); + + let point_distr = DistributionPoint::new(shift, amplitude); + let shifted_function = convolution_point_function(&point_distr, f).unwrap(); - for (i, &ti) in f.t().iter().enumerate() { - let mask = f.t().mapv(|x| (x >= ti - r.end()) && (x <= ti - r.start())); + // Convolution with a range centered on zero and of given width + let t_out = shifted_function.t().clone(); + let mut y_out = Array1::zeros(shifted_function.y().len()); + + for (i, &ti) in shifted_function.t().iter().enumerate() { + let mask = shifted_function.t().mapv(|x| (x - ti).abs() <= width / 2.0); let filtered_y = f.y() * &mask.mapv(|x| if x { 1.0 } else { 0.0 }); - y_out[i] = r.amplitude() * filtered_y.sum(); + y_out[i] = filtered_y.sum(); } Distribution::function(t_out, y_out).unwrap() @@ -280,8 +292,8 @@ mod tests { let f = Distribution::function(x, y).unwrap(); let actual = distribution_convolution(&r, &f).unwrap(); - let x = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]; - let y = array![0.0, 0.0, 2.0, 2.0, 6.0, 6.0]; + let x = array![2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; + let y = array![1.0, 1.0, 3.0, 3.0, 3.0, 1.0]; let expected = Distribution::function(x, y).unwrap(); assert_eq!(expected, actual); From 2a7c5722e931db93f4e9f086acb0732f6be25a65 Mon Sep 17 00:00:00 2001 From: Marco Molari Date: Fri, 31 Oct 2025 18:55:26 +0100 Subject: [PATCH 4/4] fix: forgotten the amplitude --- .../treetime/src/distribution/distribution_convolution.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/treetime/src/distribution/distribution_convolution.rs b/packages/treetime/src/distribution/distribution_convolution.rs index c1f2b8f3..1d5290b6 100644 --- a/packages/treetime/src/distribution/distribution_convolution.rs +++ b/packages/treetime/src/distribution/distribution_convolution.rs @@ -97,7 +97,7 @@ fn convolution_range_function(r: &DistributionRange, f: &DistributionFuncti for (i, &ti) in shifted_function.t().iter().enumerate() { let mask = shifted_function.t().mapv(|x| (x - ti).abs() <= width / 2.0); - let filtered_y = f.y() * &mask.mapv(|x| if x { 1.0 } else { 0.0 }); + let filtered_y = shifted_function.y() * &mask.mapv(|x| if x { 1.0 } else { 0.0 }); y_out[i] = filtered_y.sum(); } @@ -293,7 +293,7 @@ mod tests { let actual = distribution_convolution(&r, &f).unwrap(); let x = array![2.0, 3.0, 4.0, 5.0, 6.0, 7.0]; - let y = array![1.0, 1.0, 3.0, 3.0, 3.0, 1.0]; + let y = array![2.0, 2.0, 6.0, 6.0, 6.0, 2.0]; let expected = Distribution::function(x, y).unwrap(); assert_eq!(expected, actual);