Skip to content

Commit a0f6931

Browse files
test: dump intermediate graph states
1 parent 4794774 commit a0f6931

File tree

5 files changed

+31
-56
lines changed

5 files changed

+31
-56
lines changed

packages/treetime-convolution/src/grid_fn.rs

Lines changed: 17 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,74 +25,40 @@ impl<T: InterpElem> PartialEq for GridFn<T> {
2525

2626
impl<T: InterpElem> Eq for GridFn<T> {}
2727

28-
impl<T: InterpElem> Serialize for GridFn<T> {
28+
impl<T> Serialize for GridFn<T>
29+
where
30+
T: InterpElem + Serialize,
31+
{
2932
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
3033
where
3134
S: Serializer,
3235
{
3336
use serde::ser::SerializeStruct;
3437
let mut state = serializer.serialize_struct("GridFn", 2)?;
35-
state.serialize_field("x", self.x())?;
36-
state.serialize_field("y", self.y())?;
38+
state.serialize_field("x", &self.x().to_vec())?;
39+
state.serialize_field("y", &self.y().to_vec())?;
3740
state.end()
3841
}
3942
}
4043

41-
impl<'de> Deserialize<'de> for GridFn<f64> {
44+
impl<'de, T> Deserialize<'de> for GridFn<T>
45+
where
46+
T: InterpElem + Deserialize<'de>,
47+
{
4248
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
4349
where
4450
D: Deserializer<'de>,
4551
{
46-
use serde::de::{self, MapAccess, Visitor};
47-
use std::fmt;
48-
4952
#[derive(Deserialize)]
50-
#[serde(field_identifier, rename_all = "lowercase")]
51-
enum Field {
52-
X,
53-
Y,
54-
}
55-
56-
struct GridFnVisitor;
57-
58-
impl<'de> Visitor<'de> for GridFnVisitor {
59-
type Value = GridFn<f64>;
60-
61-
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
62-
formatter.write_str("struct GridFn")
63-
}
64-
65-
fn visit_map<V>(self, mut map: V) -> Result<GridFn<f64>, V::Error>
66-
where
67-
V: MapAccess<'de>,
68-
{
69-
let mut x: Option<Array1<f64>> = None;
70-
let mut y: Option<Array1<f64>> = None;
71-
while let Some(key) = map.next_key()? {
72-
match key {
73-
Field::X => {
74-
if x.is_some() {
75-
return Err(de::Error::duplicate_field("x"));
76-
}
77-
x = Some(map.next_value()?);
78-
},
79-
Field::Y => {
80-
if y.is_some() {
81-
return Err(de::Error::duplicate_field("y"));
82-
}
83-
y = Some(map.next_value()?);
84-
},
85-
}
86-
}
87-
let x = x.ok_or_else(|| de::Error::missing_field("x"))?;
88-
let y = y.ok_or_else(|| de::Error::missing_field("y"))?;
89-
90-
GridFn::new(x, y).map_err(|e| de::Error::custom(format!("Failed to create GridFn: {e}")))
91-
}
53+
struct GridFnHelper<T> {
54+
x: Vec<T>,
55+
y: Vec<T>,
9256
}
9357

94-
const FIELDS: &[&str] = &["x", "y"];
95-
deserializer.deserialize_struct("GridFn", FIELDS, GridFnVisitor)
58+
let helper = GridFnHelper::<T>::deserialize(deserializer)?;
59+
let x_array = Array1::from_vec(helper.x);
60+
let y_array = Array1::from_vec(helper.y);
61+
GridFn::new(x_array, y_array).map_err(serde::de::Error::custom)
9662
}
9763
}
9864

packages/treetime-convolution/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@ pub mod grid_fn;
33
pub mod testing;
44

55
use num_traits::{Num, NumCast};
6-
use serde::{Deserialize, Serialize};
76
use std::fmt::Debug;
87

9-
pub trait InterpElem: Num + NumCast + Debug + Send + PartialOrd + Copy + Serialize + for<'de> Deserialize<'de> {}
8+
pub trait InterpElem: Num + NumCast + Debug + Send + PartialOrd + Copy {}
109

1110
impl InterpElem for f64 {}
1211

packages/treetime/src/commands/timetree/inference/runner.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,21 @@ mod tests {
159159
use ordered_float::OrderedFloat;
160160
use pretty_assertions::assert_eq;
161161
use std::collections::BTreeMap;
162+
use std::path::Path;
163+
use treetime_io::json::{JsonPretty, json_write_file};
162164

163165
#[test]
164166
fn test_timetree_flu_h3n2_poisson() -> Result<(), Report> {
165167
const CLOCK_RATE_MU: f64 = 0.0028;
166168
const SEQUENCE_LENGTH_L: usize = 1400;
167169
const GRID_SIZE: usize = 200;
168170

171+
fn dump_graph(graph: &GraphAncestral, filename: &str) -> Result<(), Report> {
172+
let OUTPUT_DIR = Path::new("../../tmp/test_timetree_flu_h3n2_poisson");
173+
json_write_file(OUTPUT_DIR.join(filename), graph, JsonPretty(true))?;
174+
Ok(())
175+
}
176+
169177
// Rerooted H3N2 flu tree obtained from:
170178
//
171179
// cargo run --bin=treetime -- clock \
@@ -261,10 +269,13 @@ mod tests {
261269
load_date_constraints(&input_dates, &graph)?;
262270

263271
create_poisson_branch_distributions(&graph, CLOCK_RATE_MU, SEQUENCE_LENGTH_L, GRID_SIZE)?;
272+
dump_graph(&graph, "001_after_create_poisson_branch_distributions.json")?;
264273

265274
propagate_distributions_backward(&graph)?;
275+
dump_graph(&graph, "002_after_propagate_distributions_backward.json")?;
266276

267277
propagate_distributions_forward(&graph)?;
278+
dump_graph(&graph, "003_after_propagate_distributions_forward.json")?;
268279

269280
let actual: BTreeMap<String, f64> = graph
270281
.get_nodes()

packages/treetime/src/distribution/distribution.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ pub enum Distribution {
1717
Empty,
1818
Point(DistributionPoint<f64>),
1919
Range(DistributionRange<f64>),
20-
21-
#[serde(skip)]
2220
Function(DistributionFunction<f64>),
2321
}
2422

packages/treetime/src/distribution/distribution_function.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
use eyre::Report;
22
use ndarray::Array1;
33
use ndarray_stats::QuantileExt;
4+
use serde::{Deserialize, Serialize};
45
use treetime_convolution::{GridFn, InterpElem};
56

6-
#[derive(Clone, Debug, PartialEq, Eq)]
7+
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
78
pub struct DistributionFunction<T: InterpElem> {
89
grid_fn: GridFn<T>,
910
}

0 commit comments

Comments
 (0)