Commit f7b1169
Implement subsampling via a script
This implements a new script to encapsulate the subsampling logic formerly encoded in Snakemake rules. This is in preparation for moving this script to the augur repo, where it will become `augur subsample`. (We have chosen to develop this in the ncov repo for simplicity.) The script currently uses the same approach as the former Snakemake rules; however, Python functions are called rather than scripts / augur commands. Briefly, the steps are:

1. A subsampling scheme is provided, parsed, validated, and turned into a simple graph to indicate which samples rely on other samples having been computed (i.e. which are needed for priorities).
2. Each sample is computed by calling the run function of augur filter.
3. If priorities need to be calculated for a sample to be computed, this is achieved by calling functions from the two existing scripts.
4. The set of sequences to include in each sample is combined, and outputs written.
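For illustration, a minimal scheme of the kind this script consumes might look like the following sketch. Sample names and filter values are hypothetical; the `priorities` / `focus` entry is what links one sample to another and drives the graph construction in step 1:

focal:
  query: "region == 'Europe'"          # hypothetical filter passed through to augur filter
  subsample-max-sequences: 100
context:
  group-by: ["region", "year", "month"]
  subsample-max-sequences: 50
  priorities:
    focus: focal                       # "context" is computed after, and prioritised against, "focal"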
1 parent 946d3d2 · commit f7b1169

File tree

3 files changed: +497 -207 lines changed

scripts/subsample.py

Lines changed: 351 additions & 0 deletions
@@ -0,0 +1,351 @@
from augur.utils import AugurException
from augur.filter import run as augur_filter, register_arguments as register_filter_arguments
from augur.index import index_sequences
from augur.io import write_sequences, open_file, read_sequences, read_metadata
from get_distance_to_focal_set import get_distance_to_focal_set  # eventually from augur.priorities (or similar)
from priorities import create_priorities  # eventually from augur.priorities (or similar)
import yaml
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from os import path
import pandas as pd
from tempfile import NamedTemporaryFile
import jsonschema
# from pkg_resources import resource_string

DESCRIPTION = "Subsample sequences based on user-defined YAML configuration"

def register_arguments(parser):
    parser.add_argument('--scheme', required=True, metavar="YAML", help="subsampling scheme")
    parser.add_argument('--output-dir', required=True, metavar="PATH", help="directory to save intermediate results")
    parser.add_argument('--metadata', required=True, metavar="TSV", help="metadata")
    parser.add_argument('--alignment', required=True, metavar="FASTA", help="alignment to subsample")
    parser.add_argument('--alignment-index', required=False, metavar="INDEX", help="sequence index of alignment")
    parser.add_argument('--reference', required=True, metavar="FASTA", help="reference (which was used for alignment)")
    parser.add_argument('--include-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force include")
    parser.add_argument('--exclude-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force exclude")
    parser.add_argument('--output-fasta', required=True, metavar="FASTA", help="output subsampled sequences")
    parser.add_argument('--output-metadata', required=True, metavar="TSV", help="output subsampled metadata")
    parser.add_argument('--output-log', required=False, metavar="TSV", help="log file explaining why strains were excluded / included")
    parser.add_argument('--use-existing-outputs', required=False, action="store_true", help="use intermediate files, if they exist")

def run(args):

    config = parse_scheme(args.scheme)

    generate_sequence_index(args)

    samples = [Sample(name, data, args) for name, data in config.items()]

    graph = make_graph(samples)

    traverse_graph(
        graph,
        lambda s: s.filter()
    )

    combine_samples(args, samples)

def parse_scheme(filename):
    with open(filename) as fh:
        try:
            data = yaml.safe_load(fh)
        except yaml.YAMLError as exc:
            print(exc)
            raise AugurException(f"Error parsing subsampling scheme {filename}")
    validate_scheme(data)
    return data


def validate_scheme(scheme):
    try:
        # When we move this to `augur subsample`, load the schema via:
        # schema = yaml.safe_load(resource_string(__package__, path.join("data", "schema-subsampling.yaml")))
        with open(path.join(path.dirname(path.realpath(__file__)), "subsample_schema.yaml")) as fh:
            schema = yaml.safe_load(fh)
    except yaml.YAMLError as err:
        raise AugurException("Subsampling schema definition is not valid YAML. Error: {}".format(err))
    # check loaded schema is itself valid -- see http://python-jsonschema.readthedocs.io/en/latest/errors/
    try:
        jsonschema.Draft6Validator.check_schema(schema)
    except jsonschema.exceptions.SchemaError as err:
        raise AugurException("Subsampling schema definition is not valid. Error: {}".format(err))

    try:
        jsonschema.Draft6Validator(schema).validate(scheme)
    except jsonschema.exceptions.ValidationError as err:
        print(err)
        raise AugurException("Subsampling scheme failed validation")

class Sample():
    """
    A class to hold information about a sample. A subsampling scheme will consist of multiple
    samples. Each sample may depend on priorities based off another sample.
    """
    def __init__(self, name, config, cmd_args):
        self.name = name
        self.tmp_dir = cmd_args.output_dir
        self.alignment = cmd_args.alignment
        self.alignment_index = cmd_args.alignment_index
        self.reference = cmd_args.reference
        self.metadata = cmd_args.metadata
        self.initialise_filter_args(config, cmd_args)
        self.priorities = config.get("priorities", None)
        self.use_existing_outputs = cmd_args.use_existing_outputs

    def initialise_filter_args(self, config, subsample_args):
        """
        Currently this method is needed as we need to call `augur filter`'s `run()` with an
        argparse instance. An improvement here would be to expose appropriate filtering
        functions and call them as needed, with the output being returned rather than
        written to disk.
        """
        # create the appropriate command-line arguments for the augur filter run we want
        arg_list = [
            "--metadata", self.metadata,
            "--sequences", self.alignment,
            "--sequence-index", self.alignment_index,
            "--output", path.join(self.tmp_dir, f"sample.{self.name}.fasta"),  # filtered sequences in FASTA format
            "--output-metadata", path.join(self.tmp_dir, f"sample.{self.name}.tsv"),  # metadata for strains that passed filters
            "--output-strains", path.join(self.tmp_dir, f"sample.{self.name}.txt"),  # list of strains that passed filters (no header)
            "--output-log", path.join(self.tmp_dir, f"sample.{self.name}.log.tsv")
        ]
        # convert the YAML config into the command-line arguments for augur filter
        for name, value in config.items():
            if isinstance(value, dict):
                pass  # we explicitly ignore dictionary config entries (e.g. "priorities")
            elif isinstance(value, list):
                arg_list.append(f"--{name}")
                arg_list.extend([str(v) for v in value])
            elif isinstance(value, bool):
                if value:
                    arg_list.append(f"--{name}")
            else:
                arg_list.append(f"--{name}")
                arg_list.append(str(value))
        # mock an ArgumentParser so that we can use augur filter's interface, avoiding the need to duplicate logic
        parser = ArgumentParser(prog="Mock_Augur_Filter")
        register_filter_arguments(parser)
        self.filter_args, unused_args = parser.parse_known_args(arg_list)
        if unused_args:
            print("Warning - the following config parameters are not part of augur filter and may be ignored:")
            print(' '.join(unused_args))
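
        # Illustrative sketch of the conversion above (hypothetical config values):
        # a sample defined in the scheme as
        #   {"group-by": ["region", "year"], "subsample-max-sequences": 100, "no-probabilistic-sampling": True}
        # extends arg_list with
        #   --group-by region year --subsample-max-sequences 100 --no-probabilistic-sampling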

    def calculate_required_priorities(self):
        """
        If computation of this sample requires priority information of another sample
        (the "focus"), then this function will compute those priorities by calling
        a method on the focal sample object.
        """
        if not self.priorities:
            return
        focal_sample = self.priorities.get('sample', None)
        if not focal_sample:
            raise AugurException(f"Cannot calculate priorities needed for {self.name} as the {self.get_priority_focus_name()} sample wasn't linked")
        print(f"Calculating priorities of {focal_sample.name}, as required by {self.name}")
        priorities_file = focal_sample.calculate_priorities()
        print(f"\tSetting {self.name} filter priority file to {priorities_file}")
        self.filter_args.priority = priorities_file

    def calculate_priorities(self):
        """
        Calculate the priorities TSV file for samples in the alignment vs this sample

        Returns the filename of the priorities file (TSV)
        """

        proximity_output_file = path.join(self.tmp_dir, f"proximity_{self.name}.tsv")
        if self.use_existing_outputs and check_outputs_exist(proximity_output_file):
            print(f"Using existing proximity scores for {self.name}")
        else:
            print(f"Calculating proximity of {self.name}")
            get_distance_to_focal_set(
                self.alignment,
                self.reference,
                self.filter_args.output,
                proximity_output_file,
                ignore_seqs=["Wuhan/Hu-1/2019"]  # TODO - use the config to define this?
            )

        priorities_path = path.join(self.tmp_dir, f"priorities_{self.name}.tsv")
        if self.use_existing_outputs and check_outputs_exist(priorities_path):
            print(f"Using existing priorities for {self.name}")
        else:
            print(f"Calculating priorities of {self.name}")
            create_priorities(
                self.alignment_index,
                proximity_output_file,
                priorities_path
            )
        return priorities_path

    def get_priority_focus_name(self):
        if not self.priorities:
            return None
        return self.priorities['focus']

    def set_priority_sample(self, sample):
        if not self.priorities:
            raise AugurException(f"No priorities set for {self.name}")
        self.priorities['sample'] = sample

    def filter(self):
        print("\n---------------------------------\nCONSTRUCTING SAMPLE FOR", self.name, "\n---------------------------------")
        self.calculate_required_priorities()
        if self.use_existing_outputs and check_outputs_exist(self.filter_args.output_metadata, self.filter_args.output_strains, self.filter_args.output_log):
            print(f"Using existing filtering results for {self.name}")
        else:
            print("Calling augur filter")
            print("Filter arguments:")
            for k, v in self.filter_args.__dict__.items():
                if v is not None:
                    print(f"\t{k: <30}{v}")
            augur_filter(self.filter_args)

        # In the future, instead of `augur_filter` saving data to disk, it would return
        # data to the calling process. In lieu of that, we read the data just written.
        try:
            self.sampled_strains = set(pd.read_csv(self.filter_args.output_strains, header=None)[0])
        except pd.errors.EmptyDataError:
            self.sampled_strains = set()
        self.filter_log = pd.read_csv(
            self.filter_args.output_log,
            header=0,
            sep="\t",
            index_col="strain"
        )


def make_graph(samples):
    """
    Given the parsed scheme, construct a graph of samples to compute in an iterative fashion, such that
    priorities are calculated before the samples which depend on them.
    This is a DAG, however an extremely simple one which we can construct ourselves rather than relying on
    extra libraries.
    Constraints:
    * Each sample can only use priorities of one other sample
    * Acyclic
    Structure:
    tuple: (sample, list of descendant samples) where a "descendant" sample requires the linked sample to be
    created prior to its creation. Each entry in the list has this tuple structure.
    """

    included = set()  # set of samples added to graph
    graph = (None, [])

    # add all the samples which don't require priorities to the graph
    for sample in samples:
        if not sample.get_priority_focus_name():
            graph[1].append((sample, []))
            included.add(sample.name)

    def add_descendants(level):
        parent_sample = level[0]
        descendants = level[1]
        for sample in samples:
            if sample.name in included:
                continue
            if sample.get_priority_focus_name() == parent_sample.name:
                sample.set_priority_sample(parent_sample)
                descendants.append((sample, []))
                included.add(sample.name)
        for inner_level in descendants:
            add_descendants(inner_level)

    for level in graph[1]:
        add_descendants(level)

    # from pprint import pprint
    # print("\ngraph"); pprint(graph); print("\n")

    if len(samples) != len(included):
        raise AugurException("Incomplete graph construction")

    return graph
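
# Illustrative sketch (hypothetical scheme): with samples "focal" (no priorities) and
# "context" (priorities: {focus: focal}), make_graph returns
#   (None, [(<Sample "focal">, [(<Sample "context">, [])])])
# and traverse_graph (below) therefore filters "focal" before "context".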

def traverse_graph(level, callback):
    this_sample, descendants = level
    if this_sample:
        callback(this_sample)
    for child in descendants:
        traverse_graph(child, callback)

def generate_sequence_index(args):
    if args.alignment_index:
        print("Skipping sequence index creation as an index was provided")
        return
    print("Creating ephemeral sequence index file")
    with NamedTemporaryFile(delete=False) as sequence_index_file:
        sequence_index_path = sequence_index_file.name
    index_sequences(args.alignment, sequence_index_path)
    args.alignment_index = sequence_index_path


def combine_samples(args, samples):
    """Collect the union of strains which are included in each sample and write them to disk.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments from argparse
    samples : list[Sample]
        list of samples
    """
    print("\n\n")
    ### Form a union of each sample set, which is the subsampled strains list
    sampled_strains = set()
    for sample in samples:
        print(f"Sample \"{sample.name}\" included {len(sample.sampled_strains)} strains")
        sampled_strains.update(sample.sampled_strains)
    print(f"In total, {len(sampled_strains)} strains are included in the resulting subsampled dataset")

    ## Iterate through the input sequences, streaming a subsampled version to disk.
    sequences = read_sequences(args.alignment)
    sequences_written_to_disk = 0
    with open_file(args.output_fasta, "wt") as output_handle:
        for sequence in sequences:
            if sequence.id in sampled_strains:
                sequences_written_to_disk += 1
                write_sequences(sequence, output_handle, 'fasta')
    print(f"{sequences_written_to_disk} sequences written to {args.output_fasta}")

    ## Iterate through the metadata in chunks, writing out those entries which are in the subsample
    metadata_reader = read_metadata(
        args.metadata,
        id_columns=["strain", "name"],  # TODO - this should be an argument
        chunk_size=10000  # TODO - argument
    )
    metadata_header = True
    metadata_mode = "w"
    metadata_written_to_disk = 0
    for metadata in metadata_reader:
        df = metadata.loc[metadata.index.intersection(sampled_strains)]
        df.to_csv(
            args.output_metadata,
            sep="\t",
            header=metadata_header,
            mode=metadata_mode,
        )
        metadata_written_to_disk += df.shape[0]
        metadata_header = False
        metadata_mode = "a"
    print(f"{metadata_written_to_disk} metadata entries written to {args.output_metadata}")

    ## Combine the log files (from augur filter) for each sample into a larger log file
    ## Format TBD
    ## TODO

def check_outputs_exist(*paths):
    for p in paths:
        if not (path.exists(p) and path.isfile(p)):
            return False
    return True

if __name__ == "__main__":
    # the format of this block is designed specifically for future transfer of this script
    # into augur in the form of `augur subsample`
    parser = ArgumentParser(
        description=DESCRIPTION,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    register_arguments(parser)
    args = parser.parse_args()
    run(args)
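
The script can also be driven programmatically. A minimal sketch, assuming the script is importable as `subsample` (i.e. scripts/ is on PYTHONPATH); all file paths are placeholders:

# Hypothetical programmatic invocation of scripts/subsample.py (mirrors the __main__ block above)
from argparse import ArgumentParser
import subsample

parser = ArgumentParser()
subsample.register_arguments(parser)
args = parser.parse_args([
    "--scheme", "my_scheme.yaml",                 # placeholder paths
    "--output-dir", "results/subsampling",
    "--metadata", "data/metadata.tsv",
    "--alignment", "results/aligned.fasta",
    "--reference", "defaults/reference_seq.fasta",
    "--output-fasta", "results/subsampled.fasta",
    "--output-metadata", "results/subsampled_metadata.tsv",
])
subsample.run(args)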

scripts/subsample_schema.yaml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
type: object
title: YAML Schema for subsampling configuration to be consumed by a subsample script / command
patternProperties:
  "^[a-zA-Z0-9*_-]+$":
    type: object
    title: description of a sample
    additionalProperties: false
    properties:
      group-by:
        type: array
        minItems: 1
        items:
          type: string
      sequences-per-group:
        type: integer
      subsample-max-sequences:
        type: integer
      exclude-ambiguous-dates-by:
        type: string
        enum: ["any", "day", "month", "year"]
      min-date:
        type: ["number", "string"]
        pattern: ^\d{4}-\d{2}-\d{2}$
      max-date:
        type: ["number", "string"]
        pattern: ^\d{4}-\d{2}-\d{2}$
      exclude-where:
        type: array
        minItems: 1
        items:
          type: string
      include-where:
        type: array
        minItems: 1
        items:
          type: string
      query:
        type: string
      probabilistic-sampling:
        type: boolean
      no-probabilistic-sampling:
        type: boolean
      priorities:
        type: object
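
A single-sample configuration that validates against this schema might look like the following sketch (all values hypothetical; the date is quoted so it loads as a string matching the pattern above):

my_sample:
  group-by: ["year", "month"]
  subsample-max-sequences: 200
  min-date: "2020-03-01"
  exclude-ambiguous-dates-by: any
  query: "host == 'Human'"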
