@@ -209,6 +209,7 @@ mod tests {
209209 use super :: * ;
210210
211211 use pretty_assertions:: assert_eq;
212+ use treetime_utils:: assert_error;
212213
213214 #[ test]
214215 fn test_convolution_empty ( ) {
@@ -315,17 +316,13 @@ mod tests {
315316 let b_y = array ! [ 1.0 , 2.0 ] ; // Make values non-uniform to force Function type
316317 let b = Distribution :: function ( b_x, b_y) . unwrap ( ) ;
317318
318- let result = distribution_convolution ( & a, & b) . unwrap ( ) ;
319+ let actual = distribution_convolution ( & a, & b) . unwrap ( ) ;
320+
321+ let expected_x = array ! [ 0.0 , 1.0 , 2.0 , 3.0 ] ;
322+ let expected_y = array ! [ 1.0 , 4.0 , 5.0 , 2.0 ] ;
323+ let expected = Distribution :: function ( expected_x, expected_y) . unwrap ( ) ;
319324
320- match result {
321- Distribution :: Function ( f) => {
322- assert_eq ! ( f. t( ) . len( ) , 4 ) ; // 3 + 2 - 1 = 4
323- assert ! ( f. t( ) [ 0 ] >= 0.0 ) ;
324- assert ! ( f. t( ) . iter( ) . all( |& x| x. is_finite( ) ) ) ;
325- assert ! ( f. y( ) . iter( ) . all( |& y| y. is_finite( ) && y >= 0.0 ) ) ;
326- } ,
327- other => panic ! ( "Expected Function distribution, got {other:?}" ) ,
328- }
325+ assert_eq ! ( expected, actual) ;
329326 }
330327
331328 #[ test]
@@ -338,9 +335,9 @@ mod tests {
338335 let b_y = array ! [ 4.0 ] ;
339336 let b = Distribution :: function ( b_x, b_y) . unwrap ( ) ;
340337
341- let result = distribution_convolution ( & a, & b) . unwrap ( ) ;
338+ let actual = distribution_convolution ( & a, & b) . unwrap ( ) ;
342339 let expected = Distribution :: point ( 7.0 , 12.0 ) ;
343- assert_eq ! ( expected, result ) ;
340+ assert_eq ! ( expected, actual ) ;
344341 }
345342
346343 #[ test]
@@ -350,8 +347,9 @@ mod tests {
350347 let b_y = array ! [ 1.0 , 1.0 ] ;
351348 let b = Distribution :: function ( b_x, b_y) . unwrap ( ) ;
352349
353- let result = distribution_convolution ( & a, & b) . unwrap ( ) ;
354- assert_eq ! ( Distribution :: empty( ) , result) ;
350+ let actual = distribution_convolution ( & a, & b) . unwrap ( ) ;
351+ let expected = Distribution :: empty ( ) ;
352+ assert_eq ! ( expected, actual) ;
355353 }
356354
357355 #[ test]
@@ -365,20 +363,13 @@ mod tests {
365363 let b_y = array ! [ 1.0 , 1.0 , 1.0 ] ;
366364 let b = Distribution :: function ( b_x, b_y) . unwrap ( ) ;
367365
368- let result = distribution_convolution ( & a, & b) . unwrap ( ) ;
366+ let actual = distribution_convolution ( & a, & b) . unwrap ( ) ;
367+
368+ let expected_x = array ! [ 0.0 , 1.0 , 2.0 , 3.0 ] ;
369+ let expected_y = array ! [ 1.0 , 2.0 , 2.0 , 1.0 ] ;
370+ let expected = Distribution :: function ( expected_x, expected_y) . unwrap ( ) ;
369371
370- match result {
371- Distribution :: Function ( f) => {
372- // Should handle different spacings correctly
373- assert ! ( f. t( ) . len( ) >= 3 ) ;
374- assert ! ( f. t( ) . iter( ) . all( |& x| x. is_finite( ) ) ) ;
375- assert ! ( f. y( ) . iter( ) . all( |& y| y. is_finite( ) && y >= 0.0 ) ) ;
376- // Result should span from 0+0 to 1+2 = 3
377- assert ! ( f. t( ) [ 0 ] >= 0.0 - 1e-10 ) ;
378- assert ! ( * f. t( ) . last( ) . unwrap( ) <= 3.0 + 1e-10 ) ;
379- } ,
380- other => panic ! ( "Expected Function distribution, got {other:?}" ) ,
381- }
372+ assert_eq ! ( expected, actual) ;
382373 }
383374
384375 #[ test]
@@ -390,16 +381,13 @@ mod tests {
390381 let b_y = array ! [ 2.0 , 3.0 ] ;
391382 let b = Distribution :: function ( b_x, b_y) . unwrap ( ) ;
392383
393- let result = distribution_convolution ( & a, & b) . unwrap ( ) ;
384+ let actual = distribution_convolution ( & a, & b) . unwrap ( ) ;
385+
386+ let expected_x = array ! [ 6.0 , 7.0 ] ;
387+ let expected_y = array ! [ 2.0 , 3.0 ] ;
388+ let expected = Distribution :: function ( expected_x, expected_y) . unwrap ( ) ;
394389
395- match result {
396- Distribution :: Function ( f) => {
397- // Should shift the function by the point's position
398- assert ! ( f. t( ) [ 0 ] >= 6.0 - 1e-10 ) ; // 5 + 1
399- assert ! ( * f. t( ) . last( ) . unwrap( ) <= 7.0 + 1e-10 ) ; // 5 + 2
400- } ,
401- other => panic ! ( "Expected Function distribution, got {other:?}" ) ,
402- }
390+ assert_eq ! ( expected, actual) ;
403391 }
404392
405393 #[ test]
@@ -426,17 +414,10 @@ mod tests {
426414
427415 // In backward pass, we negate the branch distribution
428416 let negated_branch = branch_length_dist. negate ( ) ;
429- let parent_dist = distribution_convolution ( & child_time_dist, & negated_branch) . unwrap ( ) ;
417+ let actual = distribution_convolution ( & child_time_dist, & negated_branch) . unwrap ( ) ;
430418
431- match parent_dist {
432- Distribution :: Point ( p) => {
433- let parent_time = p. t ( ) ;
434- let expected_parent_time = 2013.0 - 2.5 ; // 2010.5
435- assert ! ( ( parent_time - expected_parent_time) . abs( ) < 1e-10 ) ;
436- assert ! ( parent_time < 2013.0 , "Parent should be older (earlier time) than child" ) ;
437- } ,
438- other => panic ! ( "Expected point distribution, got {other:?}" ) ,
439- }
419+ let expected = Distribution :: point ( 2010.5 , 1.0 ) ;
420+ assert_eq ! ( expected, actual) ;
440421 }
441422
442423 #[ test]
@@ -445,17 +426,10 @@ mod tests {
445426 let parent_time_dist = Distribution :: point ( 2010.0 , 1.0 ) ;
446427 let branch_length_dist = Distribution :: point ( 1.5 , 1.0 ) ;
447428
448- let child_dist = distribution_convolution ( & parent_time_dist, & branch_length_dist) . unwrap ( ) ;
429+ let actual = distribution_convolution ( & parent_time_dist, & branch_length_dist) . unwrap ( ) ;
449430
450- match child_dist {
451- Distribution :: Point ( p) => {
452- let child_time = p. t ( ) ;
453- let expected_child_time = 2010.0 + 1.5 ; // 2011.5
454- assert ! ( ( child_time - expected_child_time) . abs( ) < 1e-10 ) ;
455- assert ! ( child_time > 2010.0 , "Child should be younger (later time) than parent" ) ;
456- } ,
457- other => panic ! ( "Expected point distribution, got {other:?}" ) ,
458- }
431+ let expected = Distribution :: point ( 2011.5 , 1.0 ) ;
432+ assert_eq ! ( expected, actual) ;
459433 }
460434
461435 #[ test]
@@ -469,17 +443,12 @@ mod tests {
469443 let branch_y = array ! [ 0.3 , 0.4 , 0.3 ] ; // Uncertainty around 1.5
470444 let branch_dist = Distribution :: function ( branch_x, branch_y) . unwrap ( ) ;
471445
472- let child_dist = distribution_convolution ( & parent_dist, & branch_dist) . unwrap ( ) ;
473-
474- match child_dist {
475- Distribution :: Function ( f) => {
476- // Result should span from 2010+1=2011 to 2011+2=2013
477- assert ! ( f. t( ) [ 0 ] >= 2011.0 - 1e-10 ) ;
478- assert ! ( * f. t( ) . last( ) . unwrap( ) <= 2013.0 + 1e-10 ) ;
479- // All probabilities should be non-negative
480- assert ! ( f. y( ) . iter( ) . all( |& y| y >= 0.0 ) ) ;
481- } ,
482- other => panic ! ( "Expected function distribution, got {other:?}" ) ,
483- }
446+ let actual = distribution_convolution ( & parent_dist, & branch_dist) . unwrap ( ) ;
447+
448+ let expected_x = array ! [ 2011.0 , 2011.5 , 2012.0 , 2012.5 , 2013.0 ] ;
449+ let expected_y = array ! [ 0.03 , 0.13 , 0.18 , 0.13 , 0.03 ] ;
450+ let expected = Distribution :: function ( expected_x, expected_y) . unwrap ( ) ;
451+
452+ assert_eq ! ( expected, actual) ;
484453 }
485454}
0 commit comments