from augur.utils import AugurException
from augur.filter import run as augur_filter
from augur.index import index_sequences
from augur.io import write_sequences, open_file, read_sequences, read_metadata
from get_distance_to_focal_set import get_distance_to_focal_set  # eventually from augur.priorities (or similar)
from priorities import create_priorities  # eventually from augur.priorities (or similar)
import yaml
from argparse import Namespace, ArgumentParser, ArgumentDefaultsHelpFormatter
from os import path
import pandas as pd
from tempfile import NamedTemporaryFile
import jsonschema
# from pkg_resources import resource_string

DESCRIPTION = "Subsample sequences based on user-defined YAML configuration"

def register_arguments(parser):
    parser.add_argument('--scheme', required=True, metavar="YAML", help="subsampling scheme")
    parser.add_argument('--output-dir', required=True, metavar="PATH", help="directory to save intermediate results")
    parser.add_argument('--metadata', required=True, metavar="TSV", help="metadata")
    parser.add_argument('--alignment', required=True, metavar="FASTA", help="alignment to subsample")
    parser.add_argument('--alignment-index', required=False, metavar="INDEX", help="sequence index of alignment")
    parser.add_argument('--reference', required=True, metavar="FASTA", help="reference (which was used for alignment)")
    parser.add_argument('--include-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force include")
    parser.add_argument('--exclude-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force exclude")
    parser.add_argument('--output-fasta', required=True, metavar="FASTA", help="output subsampled sequences")
    parser.add_argument('--output-metadata', required=True, metavar="TSV", help="output subsampled metadata")
    parser.add_argument('--output-log', required=False, metavar="TSV", help="log file explaining why strains were excluded / included")

def run(args):

    config = parse_scheme(args.scheme)

    generate_sequence_index(args)

    samples = [Sample(name, data, args) for name, data in config.items()]

    graph = make_graph(samples)

    traverse_graph(
        graph,
        lambda s: s.filter()
    )

    combine_samples(args, samples)

def parse_scheme(filename):
    with open(filename) as fh:
        try:
            data = yaml.safe_load(fh)
        except yaml.YAMLError as exc:
            print(exc)
            raise AugurException(f"Error parsing subsampling scheme {filename}")
    validate_scheme(data)
    return data


def validate_scheme(scheme):
    try:
        # When we move this to `augur subsample`, load the schema via:
        # schema = yaml.safe_load(resource_string(__package__, path.join("data", "schema-subsampling.yaml")))
        with open(path.join(path.dirname(path.realpath(__file__)), "subsample_schema.yaml")) as fh:
            schema = yaml.safe_load(fh)
    except yaml.YAMLError as err:
        raise AugurException(f"Subsampling schema definition is not valid YAML. Error: {err}")
    # check loaded schema is itself valid -- see http://python-jsonschema.readthedocs.io/en/latest/errors/
    try:
        jsonschema.Draft6Validator.check_schema(schema)
    except jsonschema.exceptions.SchemaError as err:
        raise AugurException(f"Subsampling schema definition is not valid. Error: {err}")

    try:
        jsonschema.Draft6Validator(schema).validate(scheme)
    except jsonschema.exceptions.ValidationError as err:
        print(err)
        raise AugurException("Subsampling scheme failed validation")
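
# For illustration only -- a hypothetical scheme of the shape this script consumes,
# using the keys read in Sample.initialise_filter_args. The authoritative structure
# is defined by subsample_schema.yaml, not this sketch:
#
#   focal:
#     group-by: ["division", "year", "month"]
#     subsample-max-sequences: 2000
#     query: region == "Europe"
#   context:
#     group-by: ["country", "year", "month"]
#     subsample-max-sequences: 1000
#     priorities:
#       focus: focal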

class Sample():
    """
    A class to hold information about a sample. A subsampling scheme consists of multiple
    samples. Each sample may depend on priorities computed from another sample.
    """
    def __init__(self, name, config, cmd_args):
        self.name = name
        self.tmp_dir = cmd_args.output_dir
        self.alignment = cmd_args.alignment
        self.alignment_index = cmd_args.alignment_index
        self.reference = cmd_args.reference
        self.metadata = cmd_args.metadata
        self.initialise_filter_args(config, cmd_args)
        self.priorities = config.get("priorities", None)
        print("Constructor", self.name)

    def initialise_filter_args(self, config, subsample_args):
        """
        Currently this method is needed as we must call `augur filter`'s `run()` with an
        argparse Namespace. An improvement here would be to expose appropriate filtering
        functions and call them as needed, with the output being returned rather than
        written to disk.
        """
        args = Namespace()
        args.metadata = self.metadata
        args.sequences = self.alignment
        args.sequence_index = self.alignment_index
        args.metadata_chunk_size = 100000
        args.metadata_id_columns = ["strain", "name"]
        args.min_date = None
        args.max_date = None
        args.group_by = config.get('group-by', None)
        args.sequences_per_group = config.get("sequences-per-group", None)
        args.subsample_max_sequences = config.get("subsample-max-sequences", None)
        args.exclude_ambiguous_dates_by = None
        args.exclude = subsample_args.exclude_strains_file
        args.exclude_all = None
        args.exclude_where = config.get('exclude-where', None)
        args.include = subsample_args.include_strains_file
        args.include_where = None
        args.min_length = None
        args.non_nucleotide = None
        args.probabilistic_sampling = config.get("probabilistic-sampling", None)
        args.no_probabilistic_sampling = config.get("no-probabilistic-sampling", None)
        args.priority = None
        args.subsample_seed = None
        args.query = config.get("query", None)
        args.output = path.join(self.tmp_dir, f"sample.{self.name}.fasta")          # filtered sequences in FASTA format
        args.output_metadata = path.join(self.tmp_dir, f"sample.{self.name}.tsv")   # metadata for strains that passed filters
        args.output_strains = path.join(self.tmp_dir, f"sample.{self.name}.txt")    # list of strains that passed filters (no header)
        args.output_log = path.join(self.tmp_dir, f"sample.{self.name}.log.tsv")
        self.filter_args = args

    def calculate_required_priorities(self):
        """
        If computation of this sample requires priority information of another sample
        (the "focus"), then this function will compute those priorities.
        """
        if not self.priorities:
            return
        focal_sample = self.priorities.get('sample', None)
        if not focal_sample:
            raise AugurException(f"Cannot calculate priorities needed for {self.name} as the {self.get_priority_focus_name()} sample wasn't linked")
        print(f"Calculating priorities required by {self.name}")
        priorities_file = focal_sample.calculate_priorities()
        self.filter_args.priority = priorities_file

    def calculate_priorities(self):
        """
        Calculate the priorities TSV file for the alignment in the context of this sample.

        Returns the filename of the priorities file (TSV).
        """
        proximity_path = path.join(self.tmp_dir, f"proximity_{self.name}.tsv")
        print(f"Calculating proximity of {self.name}")
        get_distance_to_focal_set(
            self.alignment,
            self.reference,
            self.filter_args.output,
            proximity_path,
            ignore_seqs=["Wuhan/Hu-1/2019"]  # TODO - use the config to define this?
        )
        print(f"Calculating priorities of {self.name}")
        priorities_path = path.join(self.tmp_dir, f"priorities_{self.name}.tsv")
        create_priorities(
            self.alignment_index,
            proximity_path,
            priorities_path
        )
        return priorities_path

168+
169+ def get_priority_focus_name (self ):
170+ if not self .priorities :
171+ return None
172+ return self .priorities ['focus' ]
173+
174+ def set_priority_sample (self , sample ):
175+ if not self .priorities :
176+ raise AugurException (f"No priorities set for { self .name } " )
177+ self .priorities ['sample' ] = sample
178+
179+ def filter (self ):
180+ print ("\n ---------------------------------\n CONSTRUCTING SAMPLE FOR" , self .name , "\n ---------------------------------" )
181+ self .calculate_required_priorities ()
182+ augur_filter (self .filter_args )
183+ # In the future, instead of `augur_filter` saving data to disk, it would return
184+ # data to the calling process. In lieu of that, we read the data just written.
185+ self .sampled_strains = set (pd .read_csv (self .filter_args .output_strains , header = None )[0 ])
186+ self .filter_log = pd .read_csv (
187+ self .filter_args .output_log ,
188+ header = 0 ,
189+ sep = "\t " ,
190+ index_col = "strain"
191+ )
192+
193+
def make_graph(samples):
    """
    Given the list of samples, construct a graph such that any sample whose priorities
    depend on another sample (the "focus") is computed after that sample.
    This is a DAG, however an extremely simple one which we can construct ourselves rather
    than relying on extra libraries.
    Constraints:
        * Each sample can only use priorities of one other sample
        * Acyclic
    Structure:
        tuple: (sample, list of descendant samples) where a "descendant" sample requires
        the linked sample to be created prior to its creation. Each entry in the list has
        this tuple structure.
    """

    included = set()  # set of samples added to graph
    graph = (None, [])

    # add all the samples which don't require priorities to the graph
    for sample in samples:
        if not sample.get_priority_focus_name():
            graph[1].append((sample, []))
            included.add(sample.name)

    def add_descendants(level):
        parent_sample = level[0]
        descendants = level[1]
        for sample in samples:
            if sample.name in included:
                continue
            if sample.get_priority_focus_name() == parent_sample.name:
                sample.set_priority_sample(parent_sample)
                descendants.append((sample, []))
                included.add(sample.name)
        for inner_level in descendants:
            add_descendants(inner_level)

    for level in graph[1]:
        add_descendants(level)

    # from pprint import pprint
    # print("\ngraph"); pprint(graph); print("\n")

    if len(samples) != len(included):
        raise AugurException("Incomplete graph construction")

    return graph

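# For the hypothetical two-sample graph sketched in make_graph, the traversal below
# calls the callback on "focal" before "context", so priorities exist when needed.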
def traverse_graph(level, callback):
    this_sample, descendants = level
    if this_sample:
        callback(this_sample)
    for child in descendants:
        traverse_graph(child, callback)

def generate_sequence_index(args):
    if args.alignment_index:
        print("Skipping sequence index creation as an index was provided")
        return
    print("Creating ephemeral sequence index file")
    with NamedTemporaryFile(delete=False) as sequence_index_file:
        sequence_index_path = sequence_index_file.name
    index_sequences(args.alignment, sequence_index_path)
    args.alignment_index = sequence_index_path


def combine_samples(args, samples):
    """Collect the union of strains which are included in each sample and write them to disk.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments from argparse
    samples : list[Sample]
        list of samples
    """
    print("\n\n")
    ## Form the union of each sample set, which is the subsampled list of strains
    sampled_strains = set()
    for sample in samples:
        print(f"Sample \"{sample.name}\" included {len(sample.sampled_strains)} strains")
        sampled_strains.update(sample.sampled_strains)
    print(f"In total, {len(sampled_strains)} strains are included in the resulting subsampled dataset")

    ## Iterate through the input sequences, streaming a subsampled version to disk.
    sequences = read_sequences(args.alignment)
    sequences_written_to_disk = 0
    with open_file(args.output_fasta, "wt") as output_handle:
        for sequence in sequences:
            if sequence.id in sampled_strains:
                sequences_written_to_disk += 1
                write_sequences(sequence, output_handle, 'fasta')
    print(f"Wrote {sequences_written_to_disk} sequences to {args.output_fasta}")

    ## Iterate through the metadata in chunks, writing out those entries which are in the subsample
    metadata_reader = read_metadata(
        args.metadata,
        id_columns=["strain", "name"],  # TODO - this should be an argument
        chunk_size=10000                # TODO - argument
    )
    metadata_header = True
    metadata_mode = "w"
    metadata_written_to_disk = 0
    for metadata in metadata_reader:
        # Select only the sampled strains present in this chunk; indexing with the full
        # set would raise a KeyError for strains that fall in other chunks.
        matching_metadata = metadata.loc[metadata.index.intersection(sampled_strains)]
        metadata_written_to_disk += len(matching_metadata)
        matching_metadata.to_csv(
            args.output_metadata,
            sep="\t",
            header=metadata_header,
            mode=metadata_mode,
        )
        metadata_header = False
        metadata_mode = "a"
    print(f"Wrote {metadata_written_to_disk} metadata entries to {args.output_metadata}")

    ## Combine the log files (from augur filter) for each sample into a larger log file
    ## Format TBD
    ## TODO

if __name__ == "__main__":
    # the format of this block is designed specifically for future transfer of this script
    # into augur in the form of `augur subsample`
    parser = ArgumentParser(
        description=DESCRIPTION,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    register_arguments(parser)
    args = parser.parse_args()
    run(args)