Skip to content

Commit 0600593

Browse files
refactor: simplify test assertions
1 parent eed0881 commit 0600593

File tree

3 files changed

+98
-193
lines changed

3 files changed

+98
-193
lines changed

packages/treetime/src/distribution/distribution_convolution.rs

Lines changed: 37 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ mod tests {
209209
use super::*;
210210

211211
use pretty_assertions::assert_eq;
212+
use treetime_utils::assert_error;
212213

213214
#[test]
214215
fn test_convolution_empty() {
@@ -315,17 +316,13 @@ mod tests {
315316
let b_y = array![1.0, 2.0]; // Make values non-uniform to force Function type
316317
let b = Distribution::function(b_x, b_y).unwrap();
317318

318-
let result = distribution_convolution(&a, &b).unwrap();
319+
let actual = distribution_convolution(&a, &b).unwrap();
320+
321+
let expected_x = array![0.0, 1.0, 2.0, 3.0];
322+
let expected_y = array![1.0, 4.0, 5.0, 2.0];
323+
let expected = Distribution::function(expected_x, expected_y).unwrap();
319324

320-
match result {
321-
Distribution::Function(f) => {
322-
assert_eq!(f.t().len(), 4); // 3 + 2 - 1 = 4
323-
assert!(f.t()[0] >= 0.0);
324-
assert!(f.t().iter().all(|&x| x.is_finite()));
325-
assert!(f.y().iter().all(|&y| y.is_finite() && y >= 0.0));
326-
},
327-
other => panic!("Expected Function distribution, got {other:?}"),
328-
}
325+
assert_eq!(expected, actual);
329326
}
330327

331328
#[test]
@@ -338,9 +335,9 @@ mod tests {
338335
let b_y = array![4.0];
339336
let b = Distribution::function(b_x, b_y).unwrap();
340337

341-
let result = distribution_convolution(&a, &b).unwrap();
338+
let actual = distribution_convolution(&a, &b).unwrap();
342339
let expected = Distribution::point(7.0, 12.0);
343-
assert_eq!(expected, result);
340+
assert_eq!(expected, actual);
344341
}
345342

346343
#[test]
@@ -350,8 +347,9 @@ mod tests {
350347
let b_y = array![1.0, 1.0];
351348
let b = Distribution::function(b_x, b_y).unwrap();
352349

353-
let result = distribution_convolution(&a, &b).unwrap();
354-
assert_eq!(Distribution::empty(), result);
350+
let actual = distribution_convolution(&a, &b).unwrap();
351+
let expected = Distribution::empty();
352+
assert_eq!(expected, actual);
355353
}
356354

357355
#[test]
@@ -365,20 +363,13 @@ mod tests {
365363
let b_y = array![1.0, 1.0, 1.0];
366364
let b = Distribution::function(b_x, b_y).unwrap();
367365

368-
let result = distribution_convolution(&a, &b).unwrap();
366+
let actual = distribution_convolution(&a, &b).unwrap();
367+
368+
let expected_x = array![0.0, 1.0, 2.0, 3.0];
369+
let expected_y = array![1.0, 2.0, 2.0, 1.0];
370+
let expected = Distribution::function(expected_x, expected_y).unwrap();
369371

370-
match result {
371-
Distribution::Function(f) => {
372-
// Should handle different spacings correctly
373-
assert!(f.t().len() >= 3);
374-
assert!(f.t().iter().all(|&x| x.is_finite()));
375-
assert!(f.y().iter().all(|&y| y.is_finite() && y >= 0.0));
376-
// Result should span from 0+0 to 1+2 = 3
377-
assert!(f.t()[0] >= 0.0 - 1e-10);
378-
assert!(*f.t().last().unwrap() <= 3.0 + 1e-10);
379-
},
380-
other => panic!("Expected Function distribution, got {other:?}"),
381-
}
372+
assert_eq!(expected, actual);
382373
}
383374

384375
#[test]
@@ -390,16 +381,13 @@ mod tests {
390381
let b_y = array![2.0, 3.0];
391382
let b = Distribution::function(b_x, b_y).unwrap();
392383

393-
let result = distribution_convolution(&a, &b).unwrap();
384+
let actual = distribution_convolution(&a, &b).unwrap();
385+
386+
let expected_x = array![6.0, 7.0];
387+
let expected_y = array![2.0, 3.0];
388+
let expected = Distribution::function(expected_x, expected_y).unwrap();
394389

395-
match result {
396-
Distribution::Function(f) => {
397-
// Should shift the function by the point's position
398-
assert!(f.t()[0] >= 6.0 - 1e-10); // 5 + 1
399-
assert!(*f.t().last().unwrap() <= 7.0 + 1e-10); // 5 + 2
400-
},
401-
other => panic!("Expected Function distribution, got {other:?}"),
402-
}
390+
assert_eq!(expected, actual);
403391
}
404392

405393
#[test]
@@ -426,17 +414,10 @@ mod tests {
426414

427415
// In backward pass, we negate the branch distribution
428416
let negated_branch = branch_length_dist.negate();
429-
let parent_dist = distribution_convolution(&child_time_dist, &negated_branch).unwrap();
417+
let actual = distribution_convolution(&child_time_dist, &negated_branch).unwrap();
430418

431-
match parent_dist {
432-
Distribution::Point(p) => {
433-
let parent_time = p.t();
434-
let expected_parent_time = 2013.0 - 2.5; // 2010.5
435-
assert!((parent_time - expected_parent_time).abs() < 1e-10);
436-
assert!(parent_time < 2013.0, "Parent should be older (earlier time) than child");
437-
},
438-
other => panic!("Expected point distribution, got {other:?}"),
439-
}
419+
let expected = Distribution::point(2010.5, 1.0);
420+
assert_eq!(expected, actual);
440421
}
441422

442423
#[test]
@@ -445,17 +426,10 @@ mod tests {
445426
let parent_time_dist = Distribution::point(2010.0, 1.0);
446427
let branch_length_dist = Distribution::point(1.5, 1.0);
447428

448-
let child_dist = distribution_convolution(&parent_time_dist, &branch_length_dist).unwrap();
429+
let actual = distribution_convolution(&parent_time_dist, &branch_length_dist).unwrap();
449430

450-
match child_dist {
451-
Distribution::Point(p) => {
452-
let child_time = p.t();
453-
let expected_child_time = 2010.0 + 1.5; // 2011.5
454-
assert!((child_time - expected_child_time).abs() < 1e-10);
455-
assert!(child_time > 2010.0, "Child should be younger (later time) than parent");
456-
},
457-
other => panic!("Expected point distribution, got {other:?}"),
458-
}
431+
let expected = Distribution::point(2011.5, 1.0);
432+
assert_eq!(expected, actual);
459433
}
460434

461435
#[test]
@@ -469,17 +443,12 @@ mod tests {
469443
let branch_y = array![0.3, 0.4, 0.3]; // Uncertainty around 1.5
470444
let branch_dist = Distribution::function(branch_x, branch_y).unwrap();
471445

472-
let child_dist = distribution_convolution(&parent_dist, &branch_dist).unwrap();
473-
474-
match child_dist {
475-
Distribution::Function(f) => {
476-
// Result should span from 2010+1=2011 to 2011+2=2013
477-
assert!(f.t()[0] >= 2011.0 - 1e-10);
478-
assert!(*f.t().last().unwrap() <= 2013.0 + 1e-10);
479-
// All probabilities should be non-negative
480-
assert!(f.y().iter().all(|&y| y >= 0.0));
481-
},
482-
other => panic!("Expected function distribution, got {other:?}"),
483-
}
446+
let actual = distribution_convolution(&parent_dist, &branch_dist).unwrap();
447+
448+
let expected_x = array![2011.0, 2011.5, 2012.0, 2012.5, 2013.0];
449+
let expected_y = array![0.03, 0.13, 0.18, 0.13, 0.03];
450+
let expected = Distribution::function(expected_x, expected_y).unwrap();
451+
452+
assert_eq!(expected, actual);
484453
}
485454
}

packages/treetime/src/distribution/distribution_division.rs

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,16 @@ fn divide_function_by_function(
9999
#[cfg(test)]
100100
mod tests {
101101
use super::*;
102-
use approx::assert_ulps_eq;
103102
use ndarray::array;
104103
use treetime_utils::assert_error;
105104

106105
#[test]
107106
fn test_divide_empty_by_any() {
108107
let empty = Distribution::empty();
109108
let point = Distribution::point(1.0, 2.0);
110-
let result = distribution_division(&empty, &point).unwrap();
111-
assert_eq!(result, Distribution::empty());
109+
let actual = distribution_division(&empty, &point).unwrap();
110+
let expected = Distribution::empty();
111+
assert_eq!(expected, actual);
112112
}
113113

114114
#[test]
@@ -128,15 +128,9 @@ mod tests {
128128
let y = array![1.0, 2.0, 5.0, 4.0, 3.0];
129129
let func = Distribution::function(t, y).unwrap();
130130

131-
let result = distribution_division(&point, &func).unwrap();
132-
133-
match result {
134-
Distribution::Point(p) => {
135-
assert_ulps_eq!(p.t(), 2.0);
136-
assert_ulps_eq!(p.amplitude(), 2.0);
137-
},
138-
_ => panic!("Expected Point distribution"),
139-
}
131+
let actual = distribution_division(&point, &func).unwrap();
132+
let expected = Distribution::point(2.0, 2.0);
133+
assert_eq!(expected, actual);
140134
}
141135

142136
#[test]
@@ -146,20 +140,13 @@ mod tests {
146140
let y2 = array![2.0, 4.0, 5.0, 8.0, 10.0];
147141

148142
let dividend = Distribution::function(t.clone(), y1).unwrap();
149-
let divisor = Distribution::function(t, y2).unwrap();
143+
let divisor = Distribution::function(t.clone(), y2).unwrap();
150144

151-
let result = distribution_division(&dividend, &divisor).unwrap();
145+
let actual = distribution_division(&dividend, &divisor).unwrap();
152146

153-
match result {
154-
Distribution::Function(f) => {
155-
assert_ulps_eq!(f.y()[0], 5.0);
156-
assert_ulps_eq!(f.y()[1], 5.0);
157-
assert_ulps_eq!(f.y()[2], 6.0);
158-
assert_ulps_eq!(f.y()[3], 5.0);
159-
assert_ulps_eq!(f.y()[4], 5.0);
160-
},
161-
_ => panic!("Expected Function distribution"),
162-
}
147+
let expected_y = array![5.0, 5.0, 6.0, 5.0, 5.0];
148+
let expected = Distribution::function(t, expected_y).unwrap();
149+
assert_eq!(expected, actual);
163150
}
164151

165152
#[test]
@@ -169,19 +156,13 @@ mod tests {
169156
let y2 = array![2.0, 0.0, 5.0];
170157

171158
let dividend = Distribution::function(t.clone(), y1).unwrap();
172-
let divisor = Distribution::function(t, y2).unwrap();
159+
let divisor = Distribution::function(t.clone(), y2).unwrap();
173160

174-
let result = distribution_division(&dividend, &divisor).unwrap();
175-
// Should succeed with TINY_NUMBER handling for zero divisor, maintaining uniform grid
176-
match result {
177-
Distribution::Function(f) => {
178-
assert_eq!(f.t().len(), 3);
179-
assert_ulps_eq!(f.y()[0], 5.0); // 10.0 / 2.0
180-
assert!(f.y()[1] > 1e9); // 20.0 / TINY_NUMBER (very large)
181-
assert_ulps_eq!(f.y()[2], 6.0); // 30.0 / 5.0
182-
},
183-
_ => panic!("Expected Function distribution"),
184-
}
161+
let actual = distribution_division(&dividend, &divisor).unwrap();
162+
163+
let expected_y = array![5.0, 20.0 / TINY_NUMBER, 6.0];
164+
let expected = Distribution::function(t, expected_y).unwrap();
165+
assert_eq!(expected, actual);
185166
}
186167

187168
#[test]
@@ -191,13 +172,15 @@ mod tests {
191172
let y = array![1.0, 2.0, 5.0, 4.0, 3.0];
192173
let func = Distribution::function(t, y).unwrap();
193174

194-
let result = distribution_division(&range, &func).unwrap();
175+
let actual = distribution_division(&range, &func).unwrap();
195176

196-
match result {
177+
// Since this creates a sampled function with 100 points, we verify it's a function with correct properties
178+
match actual {
197179
Distribution::Function(f) => {
198-
assert!(f.t().len() > 0);
199-
assert!(f.y().len() > 0);
200-
assert_eq!(f.t().len(), f.y().len());
180+
assert_eq!(f.t().len(), 100);
181+
assert_eq!(f.y().len(), 100);
182+
assert!(f.t()[0] >= 1.0 - 1e-10);
183+
assert!(f.t()[99] <= 3.0 + 1e-10);
201184
},
202185
_ => panic!("Expected Function distribution"),
203186
}

0 commit comments

Comments
 (0)