Skip to content

Commit 7fdb264

Browse files
mmolariivan-aksamentov
authored andcommitted
fix: forgot multiplication by grid size in convolution_range_function
1 parent 63eb5b8 commit 7fdb264

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

packages/treetime/src/distribution/distribution_convolution.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub fn distribution_convolution(a: &Distribution, b: &Distribution) -> Result<Di
2828
Ok(Distribution::Function(convolution_point_function(a, b)?))
2929
},
3030
(Distribution::Range(a), Distribution::Function(b)) | (Distribution::Function(b), Distribution::Range(a)) => {
31-
Ok(convolution_range_function(a, b)) //
31+
convolution_range_function(a, b) //
3232
},
3333
(Distribution::Function(a), Distribution::Function(b)) => {
3434
convolution_function_function(a, b) //
@@ -78,11 +78,16 @@ fn convolution_point_function(
7878
DistributionFunction::new(t, y)
7979
}
8080

81-
fn convolution_range_function(r: &DistributionRange<f64>, f: &DistributionFunction<f64>) -> Distribution {
81+
fn convolution_range_function(
82+
r: &DistributionRange<f64>,
83+
f: &DistributionFunction<f64>,
84+
) -> Result<Distribution, Report> {
85+
// check that the distribution function has uniform grid spacing
86+
let dx = compute_uniform_spacing(f.t())?;
87+
8288
// split in a convolution with
8389
// - a point distribution (taking care of the shift + amplitude)
8490
// - an interval centered on zero and of a fixed width (taking care of the smoothing)
85-
8691
let shift = f64::midpoint(r.start(), r.end());
8792
let amplitude = r.amplitude();
8893
let half_width = (r.end() - r.start()) / 2.0;
@@ -94,13 +99,14 @@ fn convolution_range_function(r: &DistributionRange<f64>, f: &DistributionFuncti
9499
let t_out = shifted_function.t().clone();
95100
let mut y_out = Array1::zeros(shifted_function.y().len());
96101

102+
// TODO: optimize by using cumulative sums
97103
for (i, &ti) in shifted_function.t().iter().enumerate() {
98104
let mask = shifted_function.t().mapv(|x| (x - ti).abs() <= half_width);
99105
let filtered_y = shifted_function.y() * &mask.mapv(|x| if x { 1.0 } else { 0.0 });
100-
y_out[i] = filtered_y.sum();
106+
y_out[i] = filtered_y.sum() * dx;
101107
}
102108

103-
Distribution::function(t_out, y_out).unwrap()
109+
Distribution::function(t_out, y_out)
104110
}
105111

106112
fn convolution_function_function(
@@ -291,15 +297,15 @@ mod tests {
291297

292298
#[test]
293299
fn test_convolution_range_function() {
294-
let r = Distribution::range((1.0, 3.0), 2.0);
300+
let r = Distribution::range((2.0, 6.0), 0.5);
295301

296-
let x = array![0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
302+
let x = array![0.0, 2.0, 4.0, 6.0, 8.0, 10.0];
297303
let y = array![0.0, 1.0, 0.0, 2.0, 1.0, 0.0];
298304
let f = Distribution::function(x, y).unwrap();
299305
let actual = distribution_convolution(&r, &f).unwrap();
300306

301-
let x = array![2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
302-
let y = array![2.0, 2.0, 6.0, 6.0, 6.0, 2.0];
307+
let x = array![4.0, 6.0, 8.0, 10.0, 12.0, 14.0];
308+
let y = array![1.0, 1.0, 3.0, 3.0, 3.0, 1.0];
303309
let expected = Distribution::function(x, y).unwrap();
304310

305311
assert_eq!(expected, actual);

0 commit comments

Comments
 (0)