Skip to content
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@

* filter: Improved speed of using `--group-by month` on large datasets. [#1845][] (@victorlin)

### Internal changes

* Improved type annotations for better type checking. [#1860][] (@corneliusroemer)

[#1819]: https://github.com/nextstrain/augur/pull/1819
[#1844]: https://github.com/nextstrain/augur/pull/1844
[#1845]: https://github.com/nextstrain/augur/pull/1845
[#1860]: https://github.com/nextstrain/augur/pull/1860

## 31.3.0 (3 July 2025)

Expand Down
4 changes: 2 additions & 2 deletions augur/filter/include_exclude_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

try:
# pandas ≥1.5.0 only
PandasUndefinedVariableError = pd.errors.UndefinedVariableError # type: ignore[attr-defined]
PandasUndefinedVariableError = pd.errors.UndefinedVariableError # type: ignore[unused-ignore]
except AttributeError:
PandasUndefinedVariableError = pd.core.computation.ops.UndefinedVariableError # type: ignore[attr-defined, misc]
PandasUndefinedVariableError = pd.core.computation.ops.UndefinedVariableError # type: ignore[misc]


# The strains to keep as a result of applying a filter function.
Expand Down
36 changes: 21 additions & 15 deletions augur/io/sequences.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import Bio.SeqIO
import os

from augur.errors import AugurError
from augur.utils import augur
from importlib.metadata import version as installed_version
from packaging.version import Version
from shlex import quote as shquote
from shutil import which
from tempfile import NamedTemporaryFile
from textwrap import dedent
from typing import Iterator, Iterable, Union
from typing import Iterable, Iterator, Optional, Union

import Bio
from Bio.SeqFeature import SeqFeature, FeatureLocation
from Bio.SeqRecord import SeqRecord
from packaging.version import Version

from augur.errors import AugurError
from augur.utils import augur
Comment on lines -1 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not opposed to this import re-ordering, but I don't really understand the scheme being used. In particular, why is `from packaging.version import Version` located where it is?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

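This is isort's ordering. It uses three groups: modules that ship with Python (the standard library), third-party dependencies, and the project's own modules. `packaging` doesn't ship with Python, so it is a third-party dependency, and within that group it sorts alphabetically after `Bio`.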


from .file import open_file
from .print import _n, indented_list
from .shell_command_runner import run_shell_command
Expand Down Expand Up @@ -186,12 +190,12 @@ def subset_fasta(input_filename: str, output_filename: str, ids_file: str, nthre
if os.path.isfile(output_filename):
# Remove the partial output file.
os.remove(output_filename)
raise AugurError(f"Sequence output failed, see error(s) above.")
raise AugurError("Sequence output failed, see error(s) above.")
else:
raise AugurError(f"Sequence output failed, see error(s) above. The command may have already written data to stdout. You may want to clean up any partial outputs.")
raise AugurError("Sequence output failed, see error(s) above. The command may have already written data to stdout. You may want to clean up any partial outputs.")


def load_features(reference, feature_names=None):
def load_features(reference: str, feature_names: Optional[Union[set[str], list[str]]] = None) -> dict:
"""
Parse a GFF/GenBank reference file. See the docstrings for _read_gff and
_read_genbank for details.
Expand Down Expand Up @@ -254,7 +258,6 @@ def _read_nuc_annotation_from_gff(record, reference):
if len(sequence_regions)>1:
raise AugurError(f"Reference {reference!r} contains multiple ##sequence-region pragma lines. Augur can only handle GFF files with a single one.")
elif sequence_regions:
from Bio.SeqFeature import SeqFeature, FeatureLocation
(name, start, stop) = sequence_regions[0]
nuc['pragma'] = SeqFeature(
FeatureLocation(start, stop, strand=1),
Expand Down Expand Up @@ -285,7 +288,7 @@ def _read_nuc_annotation_from_gff(record, reference):
raise AugurError(f"Reference {reference!r} didn't define any information we can use to create the 'nuc' annotation. You can use a line with a 'record' or 'source' GFF type or a ##sequence-region pragma.")


def _read_gff(reference, feature_names):
def _read_gff(reference: str, feature_names: Optional[Union[set[str], list[str]]] = None) -> dict[str, SeqFeature]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the second time I'm seeing `Optional[Union[set[str], list[str]]]`; is it worth defining a type alias such as

StringSetOrList = Optional[Union[set[str], list[str]]]

somewhere?

"""
Read a GFF file. We only read GFF IDs 'gene' or 'source' (the latter may not technically
be a valid GFF field, but is used widely within the Nextstrain ecosystem).
Expand Down Expand Up @@ -325,6 +328,7 @@ def _read_gff(reference, feature_names):
# TODO: Remove warning suppression after it's addressed upstream:
# <https://github.com/chapmanb/bcbb/issues/143>
import warnings

from Bio import BiopythonDeprecationWarning
warnings.simplefilter("ignore", BiopythonDeprecationWarning)
gff_entries = list(GFF.parse(in_handle, limit_info={'gff_type': valid_types}))
Expand Down Expand Up @@ -377,7 +381,7 @@ def _read_gff(reference, feature_names):
return features


def _read_nuc_annotation_from_genbank(record, reference):
def _read_nuc_annotation_from_genbank(record: SeqRecord, reference: str):
"""
Extracts the mandatory 'source' feature. If the sequence is present we check
the length agrees with the source. (The 'ORIGIN' may be left blank,
Expand All @@ -387,7 +391,8 @@ def _read_nuc_annotation_from_genbank(record, reference):

Parameters
----------
record : :py:class:`Bio.SeqRecord.SeqRecord` reference: string
record : :py:class:`Bio.SeqRecord.SeqRecord`
reference : string

Returns
-------
Expand All @@ -411,7 +416,7 @@ def _read_nuc_annotation_from_genbank(record, reference):
return nuc


def _read_genbank(reference, feature_names):
def _read_genbank(reference: str, feature_names: Optional[Union[set[str], list[str]]] = None) -> dict[str, SeqFeature]:
"""
Read a GenBank file. We only read GenBank feature keys 'CDS' or 'source'.
We create a "feature name" via:
Expand Down Expand Up @@ -443,7 +448,8 @@ def _read_genbank(reference, feature_names):
}

features_skipped = 0
for feat in gb.features:
for feature in gb.features:
feat = feature
Comment on lines -446 to +452
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're going to rename the loop variable, I would rename all of its usages too, rather than immediately re-assigning it back to the former name.

if feat.type=='CDS':
fname = None
if "locus_tag" in feat.qualifiers:
Expand Down
2 changes: 1 addition & 1 deletion augur/io/shell_command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from signal import SIGKILL
except ImportError:
# A non-POSIX platform
SIGKILL = None # type: ignore[assignment]
SIGKILL = None


def run_shell_command(cmd, raise_errors=False, extra_env=None):
Expand Down
4 changes: 2 additions & 2 deletions augur/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,15 +499,15 @@ def pairs(xs: Iterable[str]) -> Iterable[Tuple[str, str]]:
>>> pairs(["abc=123=xyz", "=v=v"])
[('abc', '123=xyz'), ('', 'v=v')]
"""
return [tuple(x.split("=", 1)) if "=" in x else ("", x) for x in xs] # type: ignore
return [tuple(x.split("=", 1)) if "=" in x else ("", x) for x in xs] # type: ignore[misc]


def count_unique(xs: Iterable[T]) -> Iterable[Tuple[T, int]]:
# Using reduce() with a dict because it preserves input order, unlike
# itertools.groupby(), which requires a sort. Preserving order is a nice
# property for the user since we generate an error message with this.
# -trs, 24 July 2024
yield from reduce(lambda counts, x: {**counts, x: counts.get(x, 0) + 1}, xs, counts := {}).items() # type: ignore
yield from reduce(lambda counts, x: {**counts, x: counts.get(x, 0) + 1}, xs, {}).items() # type: ignore[arg-type,dict-item,return-value, call-overload]


def shquote_humanized(x):
Expand Down
6 changes: 4 additions & 2 deletions augur/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def fix_dates(d: str, dayfirst: bool = True) -> str:
try:
try:
# pandas <= 2.1
from pandas.core.tools.datetimes import parsing # type: ignore[attr-defined, import-not-found]
from pandas.core.tools.datetimes import ( # type: ignore[attr-defined]
parsing,
)
except ImportError:
# pandas >= 2.2
from pandas._libs.tslibs import parsing
Expand All @@ -45,7 +47,7 @@ def fix_dates(d: str, dayfirst: bool = True) -> str:
results = parsing.parse_datetime_string_with_reso(d, dayfirst=dayfirst)
except AttributeError:
# pandas 1.x
results = parsing.parse_time_string(d, dayfirst=dayfirst) # type: ignore[attr-defined]
results = parsing.parse_time_string(d, dayfirst=dayfirst) # type: ignore[attr-defined, unused-ignore]
if len(results) == 2:
dto, res = results
else:
Expand Down
13 changes: 7 additions & 6 deletions augur/util_support/color_parser.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from collections import defaultdict
import functools
from collections import defaultdict
from typing import Dict, List, Optional, TextIO, Tuple

from augur.data import as_file
from augur.io.file import open_file
from augur.util_support.color_parser_line import ColorParserLine


class ColorParser:
def __init__(self, *, mapping_filename, use_defaults=True):
def __init__(self, *, mapping_filename: Optional[str], use_defaults: bool = True) -> None:
self.mapping_filename = mapping_filename
self.use_defaults = use_defaults

@property
@functools.lru_cache()
def mapping(self):
colors = {}
def mapping(self) -> Dict[str, List[Tuple[str, str]]]:
colors: Dict[str, List[Tuple[str, str]]] = {}

if self.use_defaults:
with as_file("colors.tsv") as file:
Expand All @@ -27,8 +28,8 @@ def mapping(self):

return colors

def parse_file(self, file):
file_mapping = defaultdict(list)
def parse_file(self, file: TextIO) -> Dict[str, List[Tuple[str, str]]]:
file_mapping: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
for pair in [ColorParserLine(line).pair() for line in file]:
if pair is None:
continue
Expand Down
17 changes: 9 additions & 8 deletions augur/util_support/node_data_reader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import Bio.Phylo
import sys
from typing import Optional, Union

import Bio.Phylo

from augur.types import ValidationMode
from augur.util_support.node_data import DuplicatedNonDictAttributeError
from augur.util_support.node_data import NodeData
from augur.util_support.node_data import DuplicatedNonDictAttributeError, NodeData
from augur.util_support.node_data_file import NodeDataFile


Expand All @@ -19,21 +20,21 @@ class NodeDataReader:
If validation_mode is set to :py:attr:`augur.types.ValidationMode.SKIP` no validation is performed.
"""

def __init__(self, filenames, tree_file=None, validation_mode=ValidationMode.ERROR):
def __init__(self, filenames: Union[str, list[str]], tree_file: Optional[str] =None, validation_mode=ValidationMode.ERROR) -> None:
if not isinstance(filenames, list):
filenames = [filenames]
self.filenames = filenames
self.tree_file = tree_file
self.validation_mode = validation_mode

def read(self):
def read(self) -> dict:
node_data = self.build_node_data()

self.check_against_tree_file(node_data)

return node_data

def build_node_data(self):
def build_node_data(self) -> dict:
node_data = NodeData()

for node_data_file in self.node_data_files:
Expand All @@ -54,7 +55,7 @@ def build_node_data(self):
def node_data_files(self):
return (NodeDataFile(fname, validation_mode = self.validation_mode) for fname in self.filenames)

def check_against_tree_file(self, node_data):
def check_against_tree_file(self, node_data: dict) -> None:
if not self.tree_file:
return

Expand All @@ -67,7 +68,7 @@ def check_against_tree_file(self, node_data):
sys.exit(2)

@property
def node_names_from_tree_file(self):
def node_names_from_tree_file(self) -> set[str]:
try:
tree = Bio.Phylo.read(self.tree_file, "newick")
except Exception as e:
Expand Down
Loading
Loading