Skip to content

Commit 8d4746e

Browse files
Merge branch 'feat/timetree-scaffolding' into rust
2 parents 4cf5b01 + 53b838c commit 8d4746e

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

48 files changed

+1599
-52
lines changed

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ authors = [
3030
]
3131

3232

33-
3433
[patch.crates-io]
3534
lapack-sys = { git = "https://github.com/numrs/lapack-sys", rev = "d3b1ca9" }
3635
lax = { git = "https://github.com/numrs/ndarray-linalg", rev = "ac7052e" }
@@ -319,6 +318,10 @@ iter_over_hash_type = "allow"
319318
needless_range_loop = "allow"
320319
non_std_lazy_statics = "allow" # bugs clap annotations
321320

321+
# Remove once the amount of unimplemented code is reduced
322+
todo = "allow"
323+
let_underscore_untyped = "allow"
324+
322325
[workspace.lints.rustdoc]
323326
broken_intra_doc_links = "warn"
324327
bare_urls = "warn"

packages/treetime-cli/src/bin/treetime.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use treetime::commands::homoplasy::run_homoplasy::run_homoplasy;
77
use treetime::commands::mugration::run_mugration::run_mugration;
88
use treetime::commands::optimize::run::run_optimize;
99
use treetime::commands::prune::run::run_prune;
10-
use treetime::commands::timetree::run_timetree_estimation::run_timetree_estimation;
10+
use treetime::commands::timetree::run::run_timetree_estimation;
1111
use treetime::utils::global_init::global_init;
1212
use treetime::utils::openblas::get_openblas_info_str;
1313
use treetime_cli::cli::treetime_cli::{TreetimeCommands, generate_shell_completions, treetime_parse_cli_args};

packages/treetime-cli/src/cli/treetime_cli.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use treetime::commands::homoplasy::homoplasy_args::TreetimeHomoplasyArgs;
1515
use treetime::commands::mugration::mugration_args::TreetimeMugrationArgs;
1616
use treetime::commands::optimize::args::TreetimeOptimizeArgs;
1717
use treetime::commands::prune::args::TreetimePruneArgs;
18-
use treetime::commands::timetree::timetree_args::TreetimeTimetreeArgs;
18+
use treetime::commands::timetree::args::TreetimeTimetreeArgs;
1919
use treetime::utils::clap_styles::styles;
2020
use treetime::utils::global_init::setup_logger;
2121

packages/treetime/examples/convolution.rs

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ fn main() -> eyre::Result<()> {
141141
_ => {
142142
eprintln!("Unknown algorithm: {}", args.algorithm);
143143
return Ok(());
144-
}
144+
},
145145
};
146146

147147
// Compute analytical expected result
@@ -220,14 +220,7 @@ fn main() -> eyre::Result<()> {
220220
format!("{rel_err:.2}")
221221
};
222222

223-
println!(
224-
"{:>8.2} {:>12.6} {:>12.6} {:>12} {:>9}%",
225-
x_val,
226-
actual_val,
227-
expected_val,
228-
diff_str,
229-
rel_err_str
230-
);
223+
println!("{x_val:>8.2} {actual_val:>12.6} {expected_val:>12.6} {diff_str:>12} {rel_err_str:>9}%");
231224
}
232225

233226
// Create and save results structure
@@ -259,7 +252,11 @@ fn compute_domain_agreement_metrics(
259252
Ok(metrics)
260253
}
261254

262-
fn plot_input_functions(f: &treetime::distribution::reference::grid_fn::GridFn, g: &treetime::distribution::reference::grid_fn::GridFn, args: &Args) -> eyre::Result<()> {
255+
fn plot_input_functions(
256+
f: &treetime::distribution::reference::grid_fn::GridFn,
257+
g: &treetime::distribution::reference::grid_fn::GridFn,
258+
args: &Args,
259+
) -> eyre::Result<()> {
263260
let output_path = format!("{}/input_functions.svg", args.output_dir);
264261
let root = SVGBackend::new(&output_path, (800, 600)).into_drawing_area();
265262
root.fill(&WHITE)?;
@@ -307,7 +304,10 @@ fn plot_convolution_results(
307304
let max_val = actual.iter().chain(expected.iter()).fold(0.0_f64, |a, &b| a.max(b));
308305

309306
let mut chart = ChartBuilder::on(&root)
310-
.caption(&format!("Convolution Results: (f * g)(x) [{}]", args.algorithm), ("Arial", 24))
307+
.caption(
308+
format!("Convolution Results: (f * g)(x) [{}]", args.algorithm),
309+
("Arial", 24),
310+
)
311311
.margin(20)
312312
.x_label_area_size(40)
313313
.y_label_area_size(50)
@@ -318,7 +318,7 @@ fn plot_convolution_results(
318318
let actual_data: Vec<(f64, f64)> = x_grid.iter().zip(actual.iter()).map(|(&x, &y)| (x, y)).collect();
319319
chart
320320
.draw_series(LineSeries::new(actual_data, BLUE.stroke_width(2)))?
321-
.label(&format!("Actual ({})", args.algorithm))
321+
.label(format!("Actual ({})", args.algorithm))
322322
.legend(|(x, y)| PathElement::new(vec![(x, y), (x + 10, y)], BLUE));
323323

324324
let expected_data: Vec<(f64, f64)> = x_grid.iter().zip(expected.iter()).map(|(&x, &y)| (x, y)).collect();
@@ -353,7 +353,10 @@ fn plot_error_analysis(
353353
let max_error = absolute_errors.iter().fold(0.0_f64, |a, &b| a.max(b));
354354

355355
let mut chart = ChartBuilder::on(&root)
356-
.caption(&format!("Absolute Error: |Actual - Expected| [{}]", args.algorithm), ("Arial", 24))
356+
.caption(
357+
format!("Absolute Error: |Actual - Expected| [{}]", args.algorithm),
358+
("Arial", 24),
359+
)
357360
.margin(20)
358361
.x_label_area_size(40)
359362
.y_label_area_size(50)

packages/treetime/src/commands/ancestral/marginal_unified.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use parking_lot::RwLock;
1212
use std::sync::Arc;
1313

1414
/// Main entry point for marginal reconstruction
15-
pub fn run_marginal<P: PartitionMarginalOps + HasLogLh>(
15+
pub fn run_marginal<P: PartitionMarginalOps + HasLogLh + ?Sized>(
1616
graph: &GraphAncestral,
1717
partitions: &[Arc<RwLock<P>>],
1818
aln: Option<&[FastaRecord]>,
@@ -32,7 +32,7 @@ pub fn run_marginal<P: PartitionMarginalOps + HasLogLh>(
3232
}
3333

3434
/// Ancestral sequence reconstruction
35-
pub fn ancestral_reconstruction_marginal<P: PartitionMarginalOps + HasLogLh>(
35+
pub fn ancestral_reconstruction_marginal<P: PartitionMarginalOps + HasLogLh + ?Sized>(
3636
graph: &GraphAncestral,
3737
include_leaves: bool,
3838
partitions: &[Arc<RwLock<P>>],
@@ -59,7 +59,7 @@ pub fn ancestral_reconstruction_marginal<P: PartitionMarginalOps + HasLogLh>(
5959
}
6060

6161
/// Backward pass: calculates ingroup profiles
62-
fn marginal_backward<P: PartitionMarginalOps + HasLogLh>(
62+
fn marginal_backward<P: PartitionMarginalOps + HasLogLh + ?Sized>(
6363
graph: &GraphAncestral,
6464
partitions: &[Arc<RwLock<P>>],
6565
) -> Result<(), Report> {
@@ -70,7 +70,7 @@ fn marginal_backward<P: PartitionMarginalOps + HasLogLh>(
7070
Ok(())
7171
}
7272

73-
fn run_marginal_backward<P: PartitionMarginalOps + HasLogLh>(
73+
fn run_marginal_backward<P: PartitionMarginalOps + HasLogLh + ?Sized>(
7474
partitions: &[Arc<RwLock<P>>],
7575
node: &GraphNodeBackward<NodeAncestral, EdgeAncestral, ()>,
7676
) -> Result<(), Report> {
@@ -82,7 +82,7 @@ fn run_marginal_backward<P: PartitionMarginalOps + HasLogLh>(
8282
}
8383

8484
/// Forward pass: calculates outgroup profiles
85-
fn marginal_forward<P: PartitionMarginalOps + HasLogLh>(
85+
fn marginal_forward<P: PartitionMarginalOps + HasLogLh + ?Sized>(
8686
graph: &GraphAncestral,
8787
partitions: &[Arc<RwLock<P>>],
8888
) -> Result<(), Report> {
@@ -93,7 +93,7 @@ fn marginal_forward<P: PartitionMarginalOps + HasLogLh>(
9393
Ok(())
9494
}
9595

96-
fn run_marginal_forward<P: PartitionMarginalOps + HasLogLh>(
96+
fn run_marginal_forward<P: PartitionMarginalOps + HasLogLh + ?Sized>(
9797
graph: &GraphAncestral,
9898
partitions: &[Arc<RwLock<P>>],
9999
node: &GraphNodeForward<NodeAncestral, EdgeAncestral, ()>,

packages/treetime/src/commands/clock/clock_args.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::commands::clock::clock_regression::ClockOptions;
33
use crate::commands::clock::find_best_root::params::{
44
BrentParams, GoldenSectionParams, GridSearchParams, OptimizationMethod,
55
};
6-
use crate::commands::timetree::timetree_args::{BranchLengthMode, RerootMode};
6+
use crate::commands::timetree::args::{BranchLengthMode, RerootMode};
77
use crate::gtr::get_gtr::GtrModelName;
88
use clap::{Args, Parser, ValueHint};
99
use smart_default::SmartDefault;

packages/treetime/src/commands/optimize/optimize_dense.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::commands::optimize::optimize_unified::OptimizationMetrics;
2424
use crate::gtr::gtr::GTR;
2525
use crate::representation::graph_ancestral::GraphAncestral;
2626
use crate::representation::graph_dense::DenseSeqDis;
27+
use crate::representation::partition_marginal::PartitionMarginalOps;
2728
use crate::representation::partition_marginal_dense::PartitionMarginalDense;
2829
use eyre::Report;
2930
use ndarray::{Array2, Axis};
@@ -76,7 +77,10 @@ pub fn run_optimize_dense(
7677
graph: &GraphAncestral,
7778
partitions: &[Arc<RwLock<PartitionMarginalDense>>],
7879
) -> Result<(), Report> {
79-
let total_length: usize = partitions.iter().map(|part| part.read_arc().length).sum();
80+
let total_length: usize = partitions
81+
.iter()
82+
.map(|part| part.read_arc().get_sequence_length().unwrap_or(0))
83+
.sum();
8084
let one_mutation = 1.0 / total_length as f64;
8185
let n_partitions = partitions.len();
8286
graph.get_edges().iter().for_each(|edge_ref| {

packages/treetime/src/commands/optimize/optimize_sparse.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
//!
2323
use crate::commands::optimize::optimize_unified::OptimizationMetrics;
2424
use crate::graph::edge::GraphEdgeKey;
25-
use crate::representation::{graph_ancestral::GraphAncestral, partition_marginal_sparse::PartitionMarginalSparse};
25+
use crate::representation::graph_ancestral::GraphAncestral;
26+
use crate::representation::partition_marginal::PartitionMarginalOps;
27+
use crate::representation::partition_marginal_sparse::PartitionMarginalSparse;
2628
use crate::seq::mutation::Sub;
2729
use eyre::{OptionExt, Report};
2830
use itertools::Itertools;
@@ -141,7 +143,10 @@ pub fn run_optimize_sparse(
141143
graph: &GraphAncestral,
142144
partitions: &[Arc<RwLock<PartitionMarginalSparse>>],
143145
) -> Result<(), Report> {
144-
let total_length: usize = partitions.iter().map(|part| part.read_arc().length).sum();
146+
let total_length: usize = partitions
147+
.iter()
148+
.map(|part| part.read_arc().get_sequence_length().unwrap_or(0))
149+
.sum();
145150
let one_mutation = 1.0 / total_length as f64;
146151
let n_partitions = partitions.len();
147152
graph.get_edges().iter().try_for_each(|edge_ref| {

packages/treetime/src/commands/optimize/optimize_unified.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::commands::optimize::optimize_dense;
22
use crate::commands::optimize::optimize_sparse;
33
use crate::graph::edge::GraphEdgeKey;
44
use crate::representation::graph_ancestral::GraphAncestral;
5+
use crate::representation::partition_marginal::PartitionMarginalOps;
56
use crate::representation::partition_marginal_dense::PartitionMarginalDense;
67
use crate::representation::partition_marginal_sparse::PartitionMarginalSparse;
78
use eyre::Report;
@@ -220,8 +221,12 @@ pub fn run_optimize_mixed(
220221
) -> Result<(), Report> {
221222
let total_length: usize = dense_partitions
222223
.iter()
223-
.map(|part| part.read_arc().length)
224-
.chain(sparse_partitions.iter().map(|part| part.read_arc().length))
224+
.map(|part| part.read_arc().get_sequence_length().unwrap_or(0))
225+
.chain(
226+
sparse_partitions
227+
.iter()
228+
.map(|part| part.read_arc().get_sequence_length().unwrap_or(0)),
229+
)
225230
.sum();
226231

227232
let one_mutation = 1.0 / total_length as f64;
@@ -342,8 +347,12 @@ pub fn initial_guess_mixed(
342347

343348
let total_length: usize = dense_partitions
344349
.iter()
345-
.map(|part| part.read_arc().length)
346-
.chain(sparse_partitions.iter().map(|part| part.read_arc().length))
350+
.map(|part| part.read_arc().get_sequence_length().unwrap_or(0))
351+
.chain(
352+
sparse_partitions
353+
.iter()
354+
.map(|part| part.read_arc().get_sequence_length().unwrap_or(0)),
355+
)
347356
.sum();
348357

349358
let branch_length = (differences as f64) / (total_length as f64);

packages/treetime/src/commands/timetree/timetree_args.rs renamed to packages/treetime/src/commands/timetree/args.rs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use crate::alphabet::alphabet::AlphabetName;
12
use crate::commands::ancestral::anc_args::MethodAncestral;
23
use crate::gtr::get_gtr::GtrModelName;
34
use clap::{Parser, ValueEnum, ValueHint};
5+
use smart_default::SmartDefault;
46
use std::fmt::Debug;
57
use std::path::PathBuf;
68

@@ -49,7 +51,7 @@ impl Default for RerootMode {
4951
}
5052
}
5153

52-
#[derive(Parser, Debug)]
54+
#[derive(Parser, Debug, SmartDefault)]
5355
pub struct TreetimeTimetreeArgs {
5456
/// Path to one or multiple FASTA files with aligned input sequences
5557
///
@@ -130,6 +132,10 @@ pub struct TreetimeTimetreeArgs {
130132
#[clap(long)]
131133
pub keep_polytomies: bool,
132134

135+
/// Resolve polytomies using temporal information
136+
#[clap(long)]
137+
pub resolve_polytomies: bool,
138+
133139
/// Use an autocorrelated molecular clock. Strength of the Gaussian priors on branch-specific rate
134140
/// deviation and the coupling of parent and offspring rates can be specified e.g. as --relax 1.0
135141
/// 0.5. Values around 1.0 correspond to weak priors, larger values constrain rate deviations more
@@ -139,8 +145,9 @@ pub struct TreetimeTimetreeArgs {
139145

140146
/// Maximum number of iterations for which the inference cycle is run. Note that for polytomy resolution and
141147
/// coalescence models max_iter should be at least 2
142-
#[clap(long)]
143-
pub max_iter: Option<usize>,
148+
#[default = 2]
149+
#[clap(long, default_value_t = TreetimeTimetreeArgs::default().max_iter)]
150+
pub max_iter: usize,
144151

145152
/// Coalescent time scale -- sensible values are on the order of the average Hamming distance of
146153
/// contemporaneous sequences. In addition, 'opt' and 'skyline' are valid options and estimate a
@@ -183,6 +190,10 @@ pub struct TreetimeTimetreeArgs {
183190
#[clap(long, default_value = "3.0")]
184191
pub clock_filter: f64,
185192

193+
/// Number of IQD (interquartile distance) for clock filter outlier detection
194+
#[clap(long)]
195+
pub n_iqd: Option<f64>,
196+
186197
/// Reroot the tree using root-to-tip regression. Valid choices are 'min_dev', 'least-squares',
187198
/// and 'oldest'. 'least-squares' adjusts the root to minimize residuals of the root-to-tip vs
188199
/// sampling time regression, 'min_dev' minimizes variance of root-to-tip distances. 'least-
@@ -207,6 +218,10 @@ pub struct TreetimeTimetreeArgs {
207218
#[clap(long)]
208219
pub covariation: bool,
209220

221+
/// Estimate timetree with rate variation to assess sensitivity to clock rate uncertainty
222+
#[clap(long)]
223+
pub vary_rate: bool,
224+
210225
/// GTR model to use
211226
///
212227
/// '--gtr infer' will infer a model from the data. Alternatively, specify the model type. If the specified model requires additional options, use '--gtr-params' to specify those.
@@ -225,6 +240,14 @@ pub struct TreetimeTimetreeArgs {
225240
#[clap(long, value_enum, default_value_t = MethodAncestral::default())]
226241
pub method_anc: MethodAncestral,
227242

243+
/// Alphabet to use for sequences
244+
#[clap(long, value_enum, default_value_t = AlphabetName::default())]
245+
pub alphabet: AlphabetName,
246+
247+
/// Use dense representation for sequences (store full probability distributions)
248+
#[clap(long)]
249+
pub dense: Option<bool>,
250+
228251
/// Use amino acid alphabet
229252
#[clap(long)]
230253
pub aa: bool,
@@ -249,7 +272,18 @@ pub struct TreetimeTimetreeArgs {
249272
#[clap(long, short = 'O')]
250273
pub outdir: PathBuf,
251274

275+
/// Write iteration statistics to tracelog CSV file for monitoring convergence
276+
#[clap(long)]
277+
#[clap(value_hint = ValueHint::FilePath)]
278+
pub tracelog: Option<PathBuf>,
279+
252280
/// Random seed
253281
#[clap(long)]
254282
pub seed: Option<u64>,
255283
}
284+
285+
impl TreetimeTimetreeArgs {
286+
pub fn clock_filter_enabled(&self) -> bool {
287+
self.clock_filter > 0.0
288+
}
289+
}

0 commit comments

Comments
 (0)