from augur.utils import AugurException
from augur.filter import run as augur_filter, register_arguments as register_filter_arguments
from augur.index import index_sequences
from augur.io import write_sequences, open_file, read_sequences, read_metadata
from get_distance_to_focal_set import get_distance_to_focal_set  # eventually from augur.priorities (or similar)
from priorities import create_priorities  # eventually from augur.priorities (or similar)
import yaml
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from os import path
import pandas as pd
from tempfile import NamedTemporaryFile
import jsonschema
# from pkg_resources import resource_string

DESCRIPTION = "Subsample sequences based on user-defined YAML configuration"

def register_arguments(parser):
    parser.add_argument('--scheme', required=True, metavar="YAML", help="subsampling scheme")
    parser.add_argument('--output-dir', required=True, metavar="PATH", help="directory to save intermediate results")
    parser.add_argument('--metadata', required=True, metavar="TSV", help="metadata")
    parser.add_argument('--alignment', required=True, metavar="FASTA", help="alignment to subsample")
    parser.add_argument('--alignment-index', required=False, metavar="INDEX", help="sequence index of alignment")
    parser.add_argument('--reference', required=True, metavar="FASTA", help="reference (which was used for alignment)")
    parser.add_argument('--include-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force include")
    parser.add_argument('--exclude-strains-file', required=False, nargs="+", default=None, metavar="TXT", help="strains to force exclude")
    parser.add_argument('--output-fasta', required=True, metavar="FASTA", help="output subsampled sequences")
    parser.add_argument('--output-metadata', required=True, metavar="TSV", help="output subsampled metadata")
    parser.add_argument('--output-log', required=False, metavar="TSV", help="log file explaining why strains were excluded / included")
    parser.add_argument('--use-existing-outputs', required=False, action="store_true", help="use intermediate files, if they exist")

def run(args):
    config = parse_scheme(args.scheme)

    generate_sequence_index(args)

    samples = [Sample(name, data, args) for name, data in config.items()]

    graph = make_graph(samples)

    traverse_graph(
        graph,
        lambda s: s.filter()
    )

    combine_samples(args, samples)

def parse_scheme(filename):
    with open(filename) as fh:
        try:
            data = yaml.safe_load(fh)
        except yaml.YAMLError as exc:
            print(exc)
            raise AugurException(f"Error parsing subsampling scheme {filename}")
    validate_scheme(data)
    return data

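# For illustration only, a minimal scheme might look like the following
# (hypothetical sample names and values; the authoritative structure is
# defined by subsample_schema.yaml). Each top-level key names a sample,
# most keys within a sample map directly onto `augur filter` arguments,
# and an optional `priorities` block links a sample to a focal sample:
#
#   global:
#     group-by: [region, year, month]
#     subsample-max-sequences: 500
#   regional:
#     subsample-max-sequences: 500
#     priorities:
#       focus: global
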
def validate_scheme(scheme):
    try:
        # When we move this to `augur subsample`, load the schema via:
        # schema = yaml.safe_load(resource_string(__package__, path.join("data", "schema-subsampling.yaml")))
        with open(path.join(path.dirname(path.realpath(__file__)), "subsample_schema.yaml")) as fh:
            schema = yaml.safe_load(fh)
    except yaml.YAMLError as err:
        raise AugurException("Subsampling schema definition is not valid YAML. Error: {}".format(err))
    # check loaded schema is itself valid -- see http://python-jsonschema.readthedocs.io/en/latest/errors/
    try:
        jsonschema.Draft6Validator.check_schema(schema)
    except jsonschema.exceptions.SchemaError as err:
        raise AugurException("Subsampling schema definition is not valid. Error: {}".format(err))

    try:
        jsonschema.Draft6Validator(schema).validate(scheme)
    except jsonschema.exceptions.ValidationError as err:
        print(err)
        raise AugurException("Subsampling scheme failed validation")

class Sample():
    """
    A class to hold information about a sample. A subsampling scheme consists of multiple
    samples. Each sample may depend on priorities computed from another sample.
    """
    def __init__(self, name, config, cmd_args):
        self.name = name
        self.tmp_dir = cmd_args.output_dir
        self.alignment = cmd_args.alignment
        self.alignment_index = cmd_args.alignment_index
        self.reference = cmd_args.reference
        self.metadata = cmd_args.metadata
        self.initialise_filter_args(config, cmd_args)
        self.priorities = config.get("priorities", None)
        self.use_existing_outputs = cmd_args.use_existing_outputs

    def initialise_filter_args(self, config, subsample_args):
        """
        Currently this method is needed as we need to call `augur filter`'s `run()` with an
        argparse instance. An improvement here would be to expose appropriate filtering
        functions and call them as needed, with the output being returned rather than
        written to disk.
        """
        # create the appropriate command-line arguments for the augur filter run we want
        arg_list = [
            "--metadata", self.metadata,
            "--sequences", self.alignment,
            "--sequence-index", self.alignment_index,
            "--output", path.join(self.tmp_dir, f"sample.{self.name}.fasta"),         # filtered sequences in FASTA format
            "--output-metadata", path.join(self.tmp_dir, f"sample.{self.name}.tsv"),  # metadata for strains that passed filters
            "--output-strains", path.join(self.tmp_dir, f"sample.{self.name}.txt"),   # list of strains that passed filters (no header)
            "--output-log", path.join(self.tmp_dir, f"sample.{self.name}.log.tsv")
        ]
        # convert the YAML config into the command-line arguments for augur filter
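        # For example (hypothetical config values), a sample defined as
        #   {"group-by": ["region", "month"], "subsample-max-sequences": 100}
        # extends arg_list with:
        #   --group-by region month --subsample-max-sequences 100
        # Lists are expanded into repeated values, booleans become bare flags when
        # true, and dictionaries (e.g. "priorities") are handled elsewhere.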
        for name, value in config.items():
            if isinstance(value, dict):
                pass  # we explicitly ignore dictionary config entries
            elif isinstance(value, list):
                arg_list.append(f"--{name}")
                arg_list.extend([str(v) for v in value])
            elif isinstance(value, bool):
                if value:
                    arg_list.append(f"--{name}")
            else:
                arg_list.append(f"--{name}")
                arg_list.append(str(value))
        # mock an ArgumentParser so that we can use augur filter's interface, avoiding the need to duplicate logic
        parser = ArgumentParser(prog="Mock_Augur_Filter")
        register_filter_arguments(parser)
        self.filter_args, unused_args = parser.parse_known_args(arg_list)
        if unused_args:
            print("Warning - the following config parameters are not part of augur filter and may be ignored:")
            print(' '.join(unused_args))

    def calculate_required_priorities(self):
        """
        If computation of this sample requires priority information of another sample
        (the "focus"), then this function will compute those priorities by calling
        a method on the focal sample object.
        """
        if not self.priorities:
            return
        focal_sample = self.priorities.get('sample', None)
        if not focal_sample:
            raise AugurException(f"Cannot calculate priorities needed for {self.name} as the {self.get_priority_focus_name()} sample wasn't linked")
        print(f"Calculating priorities of {focal_sample.name}, as required by {self.name}")
        priorities_file = focal_sample.calculate_priorities()
        print(f"\tSetting {self.name} filter priority file to {priorities_file}")
        self.filter_args.priority = priorities_file

    def calculate_priorities(self):
        """
        Calculate the priorities TSV file for samples in the alignment vs this sample.

        Returns the filename of the priorities file (TSV)
        """
        proximity_output_file = path.join(self.tmp_dir, f"proximity_{self.name}.tsv")
        if self.use_existing_outputs and check_outputs_exist(proximity_output_file):
            print(f"Using existing proximity scores for {self.name}")
        else:
            print(f"Calculating proximity of {self.name}")
            get_distance_to_focal_set(
                self.alignment,
                self.reference,
                self.filter_args.output,
                proximity_output_file,
                ignore_seqs=["Wuhan/Hu-1/2019"]  # TODO - use the config to define this?
            )

        priorities_path = path.join(self.tmp_dir, f"priorities_{self.name}.tsv")
        if self.use_existing_outputs and check_outputs_exist(priorities_path):
            print(f"Using existing priorities for {self.name}")
        else:
            print(f"Calculating priorities of {self.name}")
            create_priorities(
                self.alignment_index,
                proximity_output_file,
                priorities_path
            )
        return priorities_path

    def get_priority_focus_name(self):
        if not self.priorities:
            return None
        return self.priorities['focus']

    def set_priority_sample(self, sample):
        if not self.priorities:
            raise AugurException(f"No priorities set for {self.name}")
        self.priorities['sample'] = sample

    def filter(self):
        print("\n---------------------------------\nCONSTRUCTING SAMPLE FOR", self.name, "\n---------------------------------")
        self.calculate_required_priorities()
        if self.use_existing_outputs and check_outputs_exist(self.filter_args.output_metadata, self.filter_args.output_strains, self.filter_args.output_log):
            print(f"Using existing filtering results for {self.name}")
        else:
            print("Calling augur filter")
            print("Filter arguments:")
            for k, v in self.filter_args.__dict__.items():
                if v is not None:
                    print(f"\t{k: <30}{v}")
            augur_filter(self.filter_args)

        # In the future, instead of `augur_filter` saving data to disk, it would return
        # data to the calling process. In lieu of that, we read the data just written.
        try:
            self.sampled_strains = set(pd.read_csv(self.filter_args.output_strains, header=None)[0])
        except pd.errors.EmptyDataError:
            self.sampled_strains = set()
        self.filter_log = pd.read_csv(
            self.filter_args.output_log,
            header=0,
            sep="\t",
            index_col="strain"
        )


def make_graph(samples):
    """
    Given the list of samples, construct a graph describing the order in which they must
    be computed, so that any priorities a sample depends on are calculated first.
    This is a DAG, however an extremely simple one which we can construct ourselves rather
    than relying on extra libraries.
    Constraints:
    * Each sample can only use priorities of one other sample
    * Acyclic
    Structure:
    tuple: (sample, list of descendant entries) where a "descendant" sample requires the
    linked sample to be created prior to its creation. Each entry in the list has this
    tuple structure.
    """
    included = set()  # set of samples added to graph
    graph = (None, [])

    # add all the samples which don't require priorities to the graph
    for sample in samples:
        if not sample.get_priority_focus_name():
            graph[1].append((sample, []))
            included.add(sample.name)

    def add_descendants(level):
        parent_sample = level[0]
        descendants = level[1]
        for sample in samples:
            if sample.name in included:
                continue
            if sample.get_priority_focus_name() == parent_sample.name:
                sample.set_priority_sample(parent_sample)
                descendants.append((sample, []))
                included.add(sample.name)
        for inner_level in descendants:
            add_descendants(inner_level)

    for level in graph[1]:
        add_descendants(level)

    # from pprint import pprint
    # print("\ngraph"); pprint(graph); print("\n")

    if len(samples) != len(included):
        raise AugurException("Incomplete graph construction")

    return graph

def traverse_graph(level, callback):
    this_sample, descendants = level
    if this_sample:
        callback(this_sample)
    for child in descendants:
        traverse_graph(child, callback)

def generate_sequence_index(args):
    if args.alignment_index:
        print("Skipping sequence index creation as an index was provided")
        return
    print("Creating ephemeral sequence index file")
    with NamedTemporaryFile(delete=False) as sequence_index_file:
        sequence_index_path = sequence_index_file.name
    index_sequences(args.alignment, sequence_index_path)
    args.alignment_index = sequence_index_path


def combine_samples(args, samples):
    """Collect the union of strains which are included in each sample and write them to disk.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments from argparse
    samples : list[Sample]
        list of samples
    """
    print("\n\n")
    ### Form a union of each sample set, which is the subsampled strains list
    sampled_strains = set()
    for sample in samples:
        print(f"Sample \"{sample.name}\" included {len(sample.sampled_strains)} strains")
        sampled_strains.update(sample.sampled_strains)
    print(f"In total, {len(sampled_strains)} strains are included in the resulting subsampled dataset")

    ## Iterate through the input sequences, streaming a subsampled version to disk.
    sequences = read_sequences(args.alignment)
    sequences_written_to_disk = 0
    with open_file(args.output_fasta, "wt") as output_handle:
        for sequence in sequences:
            if sequence.id in sampled_strains:
                sequences_written_to_disk += 1
                write_sequences(sequence, output_handle, 'fasta')
    print(f"{sequences_written_to_disk} sequences written to {args.output_fasta}")

    ## Iterate through the metadata in chunks, writing out those entries which are in the subsample
    metadata_reader = read_metadata(
        args.metadata,
        id_columns=["strain", "name"],  # TODO - this should be an argument
        chunk_size=10000  # TODO - argument
    )
    metadata_header = True
    metadata_mode = "w"
    metadata_written_to_disk = 0
    for metadata in metadata_reader:
        df = metadata.loc[metadata.index.intersection(sampled_strains)]
        df.to_csv(
            args.output_metadata,
            sep="\t",
            header=metadata_header,
            mode=metadata_mode,
        )
        metadata_written_to_disk += df.shape[0]
        metadata_header = False
        metadata_mode = "a"
    print(f"{metadata_written_to_disk} metadata entries written to {args.output_metadata}")

    ## Combine the log files (from augur filter) for each sample into a larger log file
    ## Format TBD
    ## TODO

def check_outputs_exist(*paths):
    for p in paths:
        if not (path.exists(p) and path.isfile(p)):
            return False
    return True

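# Example invocation (hypothetical file names):
#   python subsample.py --scheme my_scheme.yaml --output-dir subsample_results \
#       --metadata metadata.tsv --alignment aligned.fasta --reference reference.fasta \
#       --output-fasta subsampled.fasta --output-metadata subsampled.tsv
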
if __name__ == "__main__":
    # the format of this block is designed specifically for future transfer of this script
    # into augur in the form of `augur subsample`
    parser = ArgumentParser(
        description=DESCRIPTION,
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    register_arguments(parser)
    args = parser.parse_args()
    run(args)