Skip to content

Commit cecf691

Browse files
committed
refactor: simplify convolution function signatures
- change the signature of the `convolve` method, simplifying the underlying logic. Input/output grid are not arguments anymore, and only the grid size is passed. - adapting the testing to this new signature
1 parent 9cb9c74 commit cecf691

File tree

10 files changed

+44
-228
lines changed

10 files changed

+44
-228
lines changed

packages/treetime-convolution/src/algos/algos.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,5 @@ impl ConvolutionAlgorithm {
6464
pub trait Algo: Send + Sync {
6565
fn name(&self) -> &'static str;
6666

67-
fn convolve(
68-
&self,
69-
input_grid: &Array1<f64>,
70-
f_values: &Array1<f64>,
71-
g_values: &Array1<f64>,
72-
output_grid: &Array1<f64>,
73-
) -> Result<Array1<f64>, Report>;
67+
fn convolve(&self, dx: f64, f_values: &Array1<f64>, g_values: &Array1<f64>) -> Result<Array1<f64>, Report>;
7468
}

packages/treetime-convolution/src/algos/ndarray_conv/ndarray_conv.rs

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,13 @@ use crate::algos::algos::Algo;
22
use eyre::Report;
33
use ndarray::Array1;
44
use ndarray_conv::{ConvExt, ConvMode, PaddingMode};
5-
use ndarray_interp::interp1d::Interp1DBuilder;
6-
use treetime_utils::ndarray::is_uniform_grid;
75

86
/// Convolution using ndarray_conv library for uniform grids
9-
pub fn convolve_ndarray_conv(
10-
input_grid: &Array1<f64>,
11-
f_values: &Array1<f64>,
12-
g_values: &Array1<f64>,
13-
output_grid: &Array1<f64>,
14-
) -> Result<Array1<f64>, Report> {
15-
debug_assert!(is_uniform_grid(input_grid), "input_grid must be uniform");
16-
debug_assert!(is_uniform_grid(output_grid), "output_grid must be uniform");
17-
18-
let grid_spacing = input_grid[1] - input_grid[0];
19-
7+
/// takes as input the grid spacing, and arrays with function values
8+
pub fn convolve_ndarray_conv(dx: f64, f_values: &Array1<f64>, g_values: &Array1<f64>) -> Result<Array1<f64>, Report> {
209
let discrete_conv = f_values.conv(g_values, ConvMode::Full, PaddingMode::Zeros)?;
21-
let continuous_conv = &discrete_conv * grid_spacing;
22-
23-
let conv_result_len = discrete_conv.len();
24-
let conv_result_min = input_grid[0] + input_grid[0];
25-
let conv_result_max = input_grid[input_grid.len() - 1] + input_grid[input_grid.len() - 1];
26-
let conv_result_grid = Array1::linspace(conv_result_min, conv_result_max, conv_result_len);
27-
28-
let temp_interp = Interp1DBuilder::new(continuous_conv).x(conv_result_grid).build()?;
29-
30-
let mut result = Array1::zeros(output_grid.len());
31-
for (i, &x) in output_grid.iter().enumerate() {
32-
result[i] = temp_interp.interp_scalar(x).unwrap_or(0.0);
33-
}
34-
35-
Ok(result)
10+
let continuous_conv = &discrete_conv * dx;
11+
Ok(continuous_conv)
3612
}
3713

3814
pub struct NdarrayAlgo;
@@ -42,13 +18,7 @@ impl Algo for NdarrayAlgo {
4218
"ndarray-conv"
4319
}
4420

45-
fn convolve(
46-
&self,
47-
input_grid: &Array1<f64>,
48-
f_values: &Array1<f64>,
49-
g_values: &Array1<f64>,
50-
output_grid: &Array1<f64>,
51-
) -> Result<Array1<f64>, Report> {
52-
convolve_ndarray_conv(input_grid, f_values, g_values, output_grid)
21+
fn convolve(&self, dx: f64, f_values: &Array1<f64>, g_values: &Array1<f64>) -> Result<Array1<f64>, Report> {
22+
convolve_ndarray_conv(dx, f_values, g_values)
5323
}
5424
}

packages/treetime-convolution/src/algos/ndarray_conv_fft/ndarray_conv_fft.rs

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,17 @@ use crate::algos::algos::Algo;
22
use eyre::Report;
33
use ndarray::Array1;
44
use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
5-
use ndarray_interp::interp1d::Interp1DBuilder;
6-
use treetime_utils::ndarray::is_uniform_grid;
75

6+
/// Convolution using ndarray_conv library with FFT for uniform grids.
7+
/// Takes as input the grid spacing, and arrays with function values
88
pub fn convolve_ndarray_conv_fft(
9-
input_grid: &Array1<f64>,
9+
dx: f64,
1010
f_values: &Array1<f64>,
1111
g_values: &Array1<f64>,
12-
output_grid: &Array1<f64>,
1312
) -> Result<Array1<f64>, Report> {
14-
debug_assert!(is_uniform_grid(input_grid), "input_grid must be uniform");
15-
debug_assert!(is_uniform_grid(output_grid), "output_grid must be uniform");
16-
17-
let grid_spacing = input_grid[1] - input_grid[0];
18-
1913
let discrete_conv = f_values.conv_fft(g_values, ConvMode::Full, PaddingMode::Zeros)?;
20-
let continuous_conv = &discrete_conv * grid_spacing;
21-
22-
let conv_result_len = discrete_conv.len();
23-
let conv_result_min = input_grid[0] + input_grid[0];
24-
let conv_result_max = input_grid[input_grid.len() - 1] + input_grid[input_grid.len() - 1];
25-
let conv_result_grid = Array1::linspace(conv_result_min, conv_result_max, conv_result_len);
26-
27-
let temp_interp = Interp1DBuilder::new(continuous_conv).x(conv_result_grid).build()?;
28-
29-
let mut result = Array1::zeros(output_grid.len());
30-
for (i, &x) in output_grid.iter().enumerate() {
31-
result[i] = temp_interp.interp_scalar(x).unwrap_or(0.0);
32-
}
33-
34-
Ok(result)
14+
let continuous_conv = &discrete_conv * dx;
15+
Ok(continuous_conv)
3516
}
3617

3718
/// Ndarray FFT-based convolution algorithm
@@ -42,13 +23,7 @@ impl Algo for NdarrayConvFftAlgo {
4223
"ndarray-conv-fft"
4324
}
4425

45-
fn convolve(
46-
&self,
47-
input_grid: &Array1<f64>,
48-
f_values: &Array1<f64>,
49-
g_values: &Array1<f64>,
50-
output_grid: &Array1<f64>,
51-
) -> Result<Array1<f64>, Report> {
52-
convolve_ndarray_conv_fft(input_grid, f_values, g_values, output_grid)
26+
fn convolve(&self, dx: f64, f_values: &Array1<f64>, g_values: &Array1<f64>) -> Result<Array1<f64>, Report> {
27+
convolve_ndarray_conv_fft(dx, f_values, g_values)
5328
}
5429
}
Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,15 @@
11
use crate::algos::algos::Algo;
22
use eyre::Report;
3-
use itertools::izip;
43
use ndarray::Array1;
5-
use ndarray_interp::interp1d::Interp1DBuilder;
64

7-
/// Convolution using Riemann sum integration
8-
pub fn convolve_riemann(
9-
input_grid: &Array1<f64>,
10-
f_values: &Array1<f64>,
11-
g_values: &Array1<f64>,
12-
output_grid: &Array1<f64>,
13-
) -> Result<Array1<f64>, Report> {
14-
let g_interp = Interp1DBuilder::new(g_values.view()).x(input_grid.view()).build()?;
5+
/// Riemann convolution algorithm on uniform grids
6+
pub fn convolve_riemann(dx: f64, f_values: &Array1<f64>, g_values: &Array1<f64>) -> Result<Array1<f64>, Report> {
7+
let mut result = Array1::zeros(f_values.len() + g_values.len() - 1);
158

16-
let ds = input_grid[1] - input_grid[0];
17-
18-
let mut result = Array1::zeros(output_grid.len());
19-
for (output_idx, &x_eval) in output_grid.iter().enumerate() {
20-
let mut sum = 0.0;
21-
for (&x_input, &f_at_x_input) in izip!(input_grid, f_values) {
22-
let g_at_shifted = g_interp.interp_scalar(x_eval - x_input).unwrap_or(0.0);
23-
sum += f_at_x_input * g_at_shifted;
9+
for (i, &f_val) in f_values.iter().enumerate() {
10+
for (j, &g_val) in g_values.iter().enumerate() {
11+
result[i + j] += f_val * g_val * dx;
2412
}
25-
result[output_idx] = sum * ds;
2613
}
2714

2815
Ok(result)
@@ -35,13 +22,7 @@ impl Algo for RiemannAlgo {
3522
"riemann"
3623
}
3724

38-
fn convolve(
39-
&self,
40-
input_grid: &Array1<f64>,
41-
f_values: &Array1<f64>,
42-
g_values: &Array1<f64>,
43-
output_grid: &Array1<f64>,
44-
) -> Result<Array1<f64>, Report> {
45-
convolve_riemann(input_grid, f_values, g_values, output_grid)
25+
fn convolve(&self, dx: f64, f_values: &Array1<f64>, g_values: &Array1<f64>) -> Result<Array1<f64>, Report> {
26+
convolve_riemann(dx, f_values, g_values)
4627
}
4728
}

packages/treetime-convolution/src/testing/framework/test_case.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,4 @@ pub trait TestCase: Clone + Send + Sync + Serialize {
1414
fn input_grid_domain(&self) -> (f64, f64);
1515

1616
fn input_grid_n_points(&self) -> usize;
17-
18-
fn output_grid_domain(&self) -> (f64, f64);
19-
20-
fn output_grid_n_points(&self) -> usize;
2117
}

packages/treetime-convolution/src/testing/run.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -262,19 +262,20 @@ fn run_test<S: TestSuite>(
262262
let input_grid_n_points = test_case.input_grid_n_points();
263263
let input_grid = Array1::linspace(input_grid_min, input_grid_max, input_grid_n_points);
264264

265-
let (output_grid_min, output_grid_max) = test_case.output_grid_domain();
266-
let output_grid_n_points = test_case.output_grid_n_points();
267-
let output_grid = Array1::linspace(output_grid_min, output_grid_max, output_grid_n_points);
265+
let (evaluation_grid_min, evaluation_grid_max) = (input_grid_min * 2.0, input_grid_max * 2.0);
266+
let evaluation_grid_n_points = 2 * input_grid_n_points - 1;
267+
let evaluation_grid = Array1::linspace(evaluation_grid_min, evaluation_grid_max, evaluation_grid_n_points);
268268

269269
let f_values = suite.create_f(test_case, &input_grid)?;
270270
let g_values = suite.create_g(test_case, &input_grid)?;
271271

272-
let actual_values = algo.convolve(&input_grid, &f_values, &g_values, &output_grid)?;
273-
let expected_values = suite.analytical_convolution(test_case, &output_grid)?;
272+
let dx = input_grid[1] - input_grid[0];
273+
let actual_values = algo.convolve(dx, &f_values, &g_values)?;
274+
let expected_values = suite.analytical_convolution(test_case, &evaluation_grid)?;
274275

275276
let execution_time = start_time.elapsed().as_secs_f64() * 1000.0;
276277

277-
let metrics = ConvolutionMetrics::new(&output_grid, &actual_values, &expected_values, execution_time)?;
278+
let metrics = ConvolutionMetrics::new(&evaluation_grid, &actual_values, &expected_values, execution_time)?;
278279

279280
Ok(TestResult {
280281
algorithm: algo.name().to_owned(),
@@ -284,7 +285,7 @@ fn run_test<S: TestSuite>(
284285
f_y_values: f_values,
285286
g_x_values: input_grid,
286287
g_y_values: g_values,
287-
evaluation_grid: output_grid,
288+
evaluation_grid,
288289
actual_values,
289290
expected_values,
290291
metrics,

packages/treetime-convolution/src/testing/test_suites/exponential.rs

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ impl TestSuite for ExponentialTestSuite {
5151
b: 2.0,
5252
input_grid_domain: (-1.0, 10.0),
5353
input_grid_n_points: 1101,
54-
output_grid_domain: (-1.0, 10.0),
55-
output_grid_n_points: 1101,
5654
},
5755
ExponentialTestCase {
5856
name: "moderate_coarse_grid".to_owned(),
@@ -65,8 +63,6 @@ impl TestSuite for ExponentialTestSuite {
6563
b: 0.8,
6664
input_grid_domain: (0.0, 20.0),
6765
input_grid_n_points: 201,
68-
output_grid_domain: (0.0, 30.0),
69-
output_grid_n_points: 301,
7066
},
7167
ExponentialTestCase {
7268
name: "tight_truncation".to_owned(),
@@ -80,8 +76,6 @@ impl TestSuite for ExponentialTestSuite {
8076
b: 2.0,
8177
input_grid_domain: (0.0, 5.0),
8278
input_grid_n_points: 501,
83-
output_grid_domain: (0.0, 6.0),
84-
output_grid_n_points: 601,
8579
},
8680
ExponentialTestCase {
8781
name: "baseline_distinct_rates".to_owned(),
@@ -94,8 +88,6 @@ impl TestSuite for ExponentialTestSuite {
9488
b: 2.0,
9589
input_grid_domain: (0.0, 10.0),
9690
input_grid_n_points: 1001,
97-
output_grid_domain: (0.0, 17.0),
98-
output_grid_n_points: 1701,
9991
},
10092
ExponentialTestCase {
10193
name: "equal_rates_limit".to_owned(),
@@ -107,8 +99,6 @@ impl TestSuite for ExponentialTestSuite {
10799
b: 1.5,
108100
input_grid_domain: (0.0, 12.0),
109101
input_grid_n_points: 1201,
110-
output_grid_domain: (0.0, 24.0),
111-
output_grid_n_points: 2401,
112102
},
113103
ExponentialTestCase {
114104
name: "near_equal_rates".to_owned(),
@@ -121,8 +111,6 @@ impl TestSuite for ExponentialTestSuite {
121111
b: 1.0 + 1e-6,
122112
input_grid_domain: (0.0, 30.0),
123113
input_grid_n_points: 15001,
124-
output_grid_domain: (0.0, 60.0),
125-
output_grid_n_points: 30001,
126114
},
127115
ExponentialTestCase {
128116
name: "fast_slow_decay".to_owned(),
@@ -135,8 +123,6 @@ impl TestSuite for ExponentialTestSuite {
135123
b: 0.1,
136124
input_grid_domain: (0.0, 200.0),
137125
input_grid_n_points: 200001,
138-
output_grid_domain: (0.0, 202.0),
139-
output_grid_n_points: 202001,
140126
},
141127
ExponentialTestCase {
142128
name: "slow_fast_decay".to_owned(),
@@ -149,8 +135,6 @@ impl TestSuite for ExponentialTestSuite {
149135
b: 10.0,
150136
input_grid_domain: (0.0, 200.0),
151137
input_grid_n_points: 200001,
152-
output_grid_domain: (0.0, 202.0),
153-
output_grid_n_points: 202001,
154138
},
155139
ExponentialTestCase {
156140
name: "fine_grid_reference".to_owned(),
@@ -163,8 +147,6 @@ impl TestSuite for ExponentialTestSuite {
163147
b: 0.8,
164148
input_grid_domain: (0.0, 20.0),
165149
input_grid_n_points: 10001,
166-
output_grid_domain: (0.0, 40.0),
167-
output_grid_n_points: 20001,
168150
},
169151
ExponentialTestCase {
170152
name: "long_horizon_tail_precision".to_owned(),
@@ -176,8 +158,6 @@ impl TestSuite for ExponentialTestSuite {
176158
b: 0.04,
177159
input_grid_domain: (0.0, 750.0),
178160
input_grid_n_points: 75001,
179-
output_grid_domain: (0.0, 1350.0),
180-
output_grid_n_points: 135001,
181161
},
182162
ExponentialTestCase {
183163
name: "large_range_underflow_guard".to_owned(),
@@ -190,8 +170,6 @@ impl TestSuite for ExponentialTestSuite {
190170
b: 0.7,
191171
input_grid_domain: (0.0, 400.0),
192172
input_grid_n_points: 20001,
193-
output_grid_domain: (0.0, 800.0),
194-
output_grid_n_points: 40001,
195173
},
196174
ExponentialTestCase {
197175
name: "extreme_near_equality".to_owned(),
@@ -204,8 +182,8 @@ impl TestSuite for ExponentialTestSuite {
204182
b: 1.0 + 1e-12,
205183
input_grid_domain: (0.0, 80.0),
206184
input_grid_n_points: 16001,
207-
output_grid_domain: (0.0, 160.0),
208-
output_grid_n_points: 32001,
185+
186+
209187
},
210188
ExponentialTestCase {
211189
name: "very_fast_decays".to_owned(),
@@ -217,8 +195,8 @@ impl TestSuite for ExponentialTestSuite {
217195
b: 30.0,
218196
input_grid_domain: (0.0, 0.5),
219197
input_grid_n_points: 1001,
220-
output_grid_domain: (0.0, 1.0),
221-
output_grid_n_points: 2001,
198+
199+
222200
},
223201
]
224202
}
@@ -236,8 +214,8 @@ pub struct ExponentialTestCase {
236214
pub b: f64,
237215
pub input_grid_domain: (f64, f64),
238216
pub input_grid_n_points: usize,
239-
pub output_grid_domain: (f64, f64),
240-
pub output_grid_n_points: usize,
217+
218+
241219
}
242220

243221
impl TestCase for ExponentialTestCase {
@@ -269,11 +247,4 @@ impl TestCase for ExponentialTestCase {
269247
self.input_grid_n_points
270248
}
271249

272-
fn output_grid_domain(&self) -> (f64, f64) {
273-
self.output_grid_domain
274-
}
275-
276-
fn output_grid_n_points(&self) -> usize {
277-
self.output_grid_n_points
278-
}
279250
}

0 commit comments

Comments
 (0)