from augur.utils import AugurException
from augur.filter import run as augur_filter
from augur.index import index_sequences
from augur.io import write_sequences, open_file, read_sequences, read_metadata
from get_distance_to_focal_set import get_distance_to_focal_set  # eventually from augur.priorities (or similar)
from priorities import create_priorities  # eventually from augur.priorities (or similar)
import yaml
from argparse import Namespace, ArgumentParser, ArgumentDefaultsHelpFormatter
from os import path
import pandas as pd
from tempfile import NamedTemporaryFile
import jsonschema
# from pkg_resources import resource_string

DESCRIPTION = "Subsample sequences based on user-defined YAML configuration"
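
# An illustrative scheme (sample names and values here are hypothetical; the recognised
# keys mirror those read in Sample.initialise_filter_args below):
#
#   focal:
#     group-by: ["division", "year", "month"]
#     subsample-max-sequences: 2000
#   context:
#     group-by: ["country", "year", "month"]
#     subsample-max-sequences: 1000
#     priorities:
#       focus: focal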

def register_arguments(parser):
    parser.add_argument('--scheme', required=True, metavar="YAML", help="subsampling scheme")
    parser.add_argument('--output-dir', required=True, metavar="PATH", help="directory to save intermediate results")
    parser.add_argument('--metadata', required=True, metavar="TSV", help="metadata")
    parser.add_argument('--alignment', required=True, metavar="FASTA", help="alignment to subsample")
    parser.add_argument('--alignment-index', required=False, metavar="INDEX", help="sequence index of alignment")
    parser.add_argument('--reference', required=True, metavar="FASTA", help="reference (which was used for alignment)")
    parser.add_argument('--include-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force include")
    parser.add_argument('--exclude-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force exclude")
    parser.add_argument('--output-fasta', required=True, metavar="FASTA", help="output subsampled sequences")
    parser.add_argument('--output-metadata', required=True, metavar="TSV", help="output subsampled metadata")
    parser.add_argument('--output-log', required=False, metavar="TSV", help="log file explaining why strains were excluded / included")
    parser.add_argument('--use-existing-outputs', required=False, action="store_true", help="use intermediate files, if they exist")

def run(args):

    config = parse_scheme(args.scheme)

    generate_sequence_index(args)

    samples = [Sample(name, data, args) for name, data in config.items()]

    graph = make_graph(samples)

    traverse_graph(
        graph,
        lambda s: s.filter()
    )

    combine_samples(args, samples)

def parse_scheme(filename):
    with open(filename) as fh:
        try:
            data = yaml.safe_load(fh)
        except yaml.YAMLError as exc:
            print(exc)
            raise AugurException(f"Error parsing subsampling scheme {filename}")
    validate_scheme(data)
    return data


def validate_scheme(scheme):
    try:
        # When we move this to `augur subsample`, load the schema via:
        # schema = yaml.safe_load(resource_string(__package__, path.join("data", "schema-subsampling.yaml")))
        with open(path.join(path.dirname(path.realpath(__file__)), "subsample_schema.yaml")) as fh:
            schema = yaml.safe_load(fh)
    except yaml.YAMLError as err:
        raise AugurException("Subsampling schema definition is not valid YAML. Error: {}".format(err))
    # check loaded schema is itself valid -- see http://python-jsonschema.readthedocs.io/en/latest/errors/
    try:
        jsonschema.Draft6Validator.check_schema(schema)
    except jsonschema.exceptions.SchemaError as err:
        raise AugurException("Subsampling schema definition is not valid. Error: {}".format(err))

    try:
        jsonschema.Draft6Validator(schema).validate(scheme)
    except jsonschema.exceptions.ValidationError as err:
        print(err)
        raise AugurException("Subsampling scheme failed validation")
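
# A minimal sketch of what `subsample_schema.yaml` could contain. The real schema file is
# not shown here, so every key below is an assumption; it only needs to be a valid JSON
# Schema (Draft 6) document expressed as YAML, e.g.:
#
#   type: object
#   patternProperties:
#     "^[A-Za-z0-9_-]+$":
#       type: object
#       properties:
#         group-by:
#           type: array
#           items: {type: string}
#         subsample-max-sequences:
#           type: integer
#         priorities:
#           type: object
#           properties:
#             focus: {type: string}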

class Sample():
    """
    A class to hold information about a sample. A subsampling scheme consists of multiple
    samples. Each sample may depend on priorities computed from another sample.
    """
    def __init__(self, name, config, cmd_args):
        self.name = name
        self.tmp_dir = cmd_args.output_dir
        self.alignment = cmd_args.alignment
        self.alignment_index = cmd_args.alignment_index
        self.reference = cmd_args.reference
        self.metadata = cmd_args.metadata
        self.initialise_filter_args(config, cmd_args)
        self.priorities = config.get("priorities", None)
        self.use_existing_outputs = cmd_args.use_existing_outputs
        print("Constructor", self.name)

    def initialise_filter_args(self, config, subsample_args):
        """
        This method is currently needed because we must call `augur filter`'s `run()` with
        an argparse instance. An improvement here would be to expose appropriate filtering
        functions and call them as needed, with the output being returned rather than
        written to disk.
        """
        args = Namespace()
        args.metadata = self.metadata
        args.sequences = self.alignment
        args.sequence_index = self.alignment_index
        args.metadata_chunk_size = 100000
        args.metadata_id_columns = ["strain", "name"]
        args.min_date = None
        args.max_date = None
        args.group_by = config.get('group-by', None)
        args.sequences_per_group = config.get("sequences-per-group", None)
        args.subsample_max_sequences = config.get("subsample-max-sequences", None)
        args.exclude_ambiguous_dates_by = None
        args.exclude = subsample_args.exclude_strains_file
        args.exclude_all = None
        args.exclude_where = config.get('exclude-where', None)
        args.include = subsample_args.include_strains_file
        args.include_where = None
        args.min_length = None
        args.non_nucleotide = None
        args.probabilistic_sampling = config.get("probabilistic-sampling", None)
        args.no_probabilistic_sampling = config.get("no-probabilistic-sampling", None)
        args.priority = None
        args.subsample_seed = None
        args.query = config.get("query", None)
        args.output = path.join(self.tmp_dir, f"sample.{self.name}.fasta")  # filtered sequences in FASTA format
        args.output_metadata = path.join(self.tmp_dir, f"sample.{self.name}.tsv")  # metadata for strains that passed filters
        args.output_strains = path.join(self.tmp_dir, f"sample.{self.name}.txt")  # list of strains that passed filters (no header)
        args.output_log = path.join(self.tmp_dir, f"sample.{self.name}.log.tsv")
        self.filter_args = args


    def calculate_required_priorities(self):
        """
        If computation of this sample requires priority information from another sample
        (the "focus"), then this method will compute those priorities.
        """
        if not self.priorities:
            return
        focal_sample = self.priorities.get('sample', None)
        if not focal_sample:
            raise AugurException(f"Cannot calculate priorities needed for {self.name} as the {self.get_priority_focus_name()} sample wasn't linked")
        print(f"Calculating priorities required by {self.name}")
        priorities_file = focal_sample.calculate_priorities()
        self.filter_args.priority = priorities_file

    def calculate_priorities(self):
        """
        Calculate the priorities TSV file for the alignment in the context of this sample.

        Returns the filename of the priorities file (TSV).
        """
        proximity_output_file = path.join(self.tmp_dir, f"proximity_{self.name}.tsv")
        if self.use_existing_outputs and check_outputs_exist(proximity_output_file):
            print(f"Using existing proximity scores for {self.name}")
        else:
            print(f"Calculating proximity of {self.name}")
            get_distance_to_focal_set(
                self.alignment,
                self.reference,
                self.filter_args.output,
                proximity_output_file,
                ignore_seqs=["Wuhan/Hu-1/2019"]  # TODO - use the config to define this?
            )

        priorities_path = path.join(self.tmp_dir, f"priorities_{self.name}.tsv")
        if self.use_existing_outputs and check_outputs_exist(priorities_path):
            print(f"Using existing priorities for {self.name}")
        else:
            print(f"Calculating priorities of {self.name}")
            create_priorities(
                self.alignment_index,
                proximity_output_file,
                priorities_path
            )
        return priorities_path

    def get_priority_focus_name(self):
        if not self.priorities:
            return None
        return self.priorities['focus']

    def set_priority_sample(self, sample):
        if not self.priorities:
            raise AugurException(f"No priorities set for {self.name}")
        self.priorities['sample'] = sample

    def filter(self):
        print("\n---------------------------------\nCONSTRUCTING SAMPLE FOR", self.name, "\n---------------------------------")
        self.calculate_required_priorities()

        if self.use_existing_outputs and check_outputs_exist(self.filter_args.output_metadata, self.filter_args.output_strains, self.filter_args.output_log):
            print(f"Using existing filtering results for {self.name}")
        else:
            augur_filter(self.filter_args)

        # In the future, instead of `augur_filter` saving data to disk, it would return
        # data to the calling process. In lieu of that, we read the data just written.
        self.sampled_strains = set(pd.read_csv(self.filter_args.output_strains, header=None)[0])
        self.filter_log = pd.read_csv(
            self.filter_args.output_log,
            header=0,
            sep="\t",
            index_col="strain"
        )


def make_graph(samples):
    """
    Given the parsed samples, construct a graph of samples to process in an iterative
    fashion, such that any sample whose priorities depend on another sample is processed
    after that sample.
    This is a DAG, however an extremely simple one which we can construct ourselves rather
    than relying on extra libraries.
    Constraints:
    * Each sample can only use priorities of one other sample
    * Acyclic
    Structure:
    tuple: (sample, list of descendant entries) where a "descendant" sample requires the
    linked sample to be created prior to its own creation. Each entry in the list has this
    tuple structure.
    """

    included = set()  # set of sample names added to the graph
    graph = (None, [])

    # add all the samples which don't require priorities to the graph
    for sample in samples:
        if not sample.get_priority_focus_name():
            graph[1].append((sample, []))
            included.add(sample.name)

    def add_descendants(level):
        parent_sample = level[0]
        descendants = level[1]
        for sample in samples:
            if sample.name in included:
                continue
            if sample.get_priority_focus_name() == parent_sample.name:
                sample.set_priority_sample(parent_sample)
                descendants.append((sample, []))
                included.add(sample.name)
        for inner_level in descendants:
            add_descendants(inner_level)

    for level in graph[1]:
        add_descendants(level)

    # from pprint import pprint
    # print("\ngraph"); pprint(graph); print("\n")

    if len(samples) != len(included):
        raise AugurException("Incomplete graph construction")

    return graph
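
# Illustrative example (sample names are hypothetical): given samples A (no priorities),
# B (priorities focus: A) and C (priorities focus: B), make_graph returns
#   (None, [(A, [(B, [(C, [])])])])
# and traverse_graph (below) visits A, then B, then C, so each sample's priorities
# have been computed before they are needed.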

def traverse_graph(level, callback):
    this_sample, descendants = level
    if this_sample:
        callback(this_sample)
    for child in descendants:
        traverse_graph(child, callback)

def generate_sequence_index(args):
    if args.alignment_index:
        print("Skipping sequence index creation as an index was provided")
        return
    print("Creating ephemeral sequence index file")
    with NamedTemporaryFile(delete=False) as sequence_index_file:
        sequence_index_path = sequence_index_file.name
    index_sequences(args.alignment, sequence_index_path)
    args.alignment_index = sequence_index_path


def combine_samples(args, samples):
    """Collect the union of strains which are included in each sample and write them to disk.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments from argparse
    samples : list[Sample]
        List of samples
    """
    print("\n\n")
    ### Form a union of each sample set, which is the subsampled strains list
    sampled_strains = set()
    for sample in samples:
        print(f"Sample \"{sample.name}\" included {len(sample.sampled_strains)} strains")
        sampled_strains.update(sample.sampled_strains)
    print(f"In total, {len(sampled_strains)} strains are included in the resulting subsampled dataset")

    ## Iterate through the input sequences, streaming a subsampled version to disk.
    sequences = read_sequences(args.alignment)
    sequences_written_to_disk = 0
    with open_file(args.output_fasta, "wt") as output_handle:
        for sequence in sequences:
            if sequence.id in sampled_strains:
                sequences_written_to_disk += 1
                write_sequences(sequence, output_handle, 'fasta')
    print(f"{sequences_written_to_disk} sequences written to {args.output_fasta}")

    ## Iterate through the metadata in chunks, writing out those entries which are in the subsample
    metadata_reader = read_metadata(
        args.metadata,
        id_columns=["strain", "name"],  # TODO - this should be an argument
        chunk_size=10000  # TODO - argument
    )
    metadata_header = True
    metadata_mode = "w"
    metadata_written_to_disk = 0
    for metadata in metadata_reader:
        df = metadata.loc[metadata.index.intersection(sampled_strains)]
        df.to_csv(
            args.output_metadata,
            sep="\t",
            header=metadata_header,
            mode=metadata_mode,
        )
        metadata_written_to_disk += df.shape[0]
        metadata_header = False
        metadata_mode = "a"
    print(f"{metadata_written_to_disk} metadata entries written to {args.output_metadata}")

    ## Combine the log files (from augur filter) for each sample into a larger log file
    ## Format TBD
    ## TODO

def check_outputs_exist(*paths):
    for p in paths:
        if not (path.exists(p) and path.isfile(p)):
            return False
    return True

if __name__ == "__main__":
    # the format of this block is designed specifically for future transfer of this script
    # into augur in the form of `augur subsample`
    parser = ArgumentParser(
        usage=DESCRIPTION,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    register_arguments(parser)
    args = parser.parse_args()
    run(args)
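
# Example invocation (the script name and all paths are illustrative):
#   python subsample.py \
#     --scheme subsampling_scheme.yaml \
#     --output-dir results/subsampling \
#     --metadata data/metadata.tsv \
#     --alignment results/aligned.fasta \
#     --reference defaults/reference_seq.fasta \
#     --output-fasta results/subsampled.fasta \
#     --output-metadata results/subsampled_metadata.tsv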