Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring and Optimization #678

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions scripts/add_branch_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ def extract_spike_mutations(node_data):
return data

def extract_clade_labels(node_data):
data = {}
for name, node in node_data["nodes"].items():
if "clade_annotation" in node:
data[name] = node["clade_annotation"]
return data
return {
name: node["clade_annotation"]
for name, node in node_data["nodes"].items()
if "clade_annotation" in node
}

if __name__ == '__main__':
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -44,14 +44,14 @@ def extract_clade_labels(node_data):

def attach_labels(n): # closure
if n["name"] in spike_mutations or n["name"] in clade_labels:
if "branch_attrs" not in n:
n["branch_attrs"]={}
if "labels" not in n["branch_attrs"]:
n["branch_attrs"]["labels"]={}
if n["name"] in spike_mutations:
n["branch_attrs"]["labels"]["spike_mutations"] = spike_mutations[n["name"]]
if n["name"] in clade_labels:
n["branch_attrs"]["labels"]["emerging_lineage"] = clade_labels[n["name"]]
if "branch_attrs" not in n:
n["branch_attrs"]={}
if "labels" not in n["branch_attrs"]:
n["branch_attrs"]["labels"]={}
if n["name"] in spike_mutations:
n["branch_attrs"]["labels"]["spike_mutations"] = spike_mutations[n["name"]]
if n["name"] in clade_labels:
n["branch_attrs"]["labels"]["emerging_lineage"] = clade_labels[n["name"]]

if "children" in n:
for c in n["children"]:
Expand Down
127 changes: 78 additions & 49 deletions scripts/add_labels.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,94 @@
import argparse
import json
from Bio import Phylo
from collections import defaultdict
from augur.utils import read_metadata
from Bio import SeqIO
import csv
import sys

def attach_labels(d, labeled_nodes):
if "children" in d:
for c in d["children"]:
if c["name"] in labeled_nodes:
if "labels" not in c["branch_attrs"]:
c["branch_attrs"]["labels"] = {}
c['branch_attrs']['labels']['mlabel'] = labeled_nodes[c["name"]][0]
print(c['branch_attrs']['labels'])
attach_labels(c, labeled_nodes)
EMPTY = ''

# This script was written in preparation for a future augur where commands
# may take multiple metadata files, thus making this script unnecessary!
#
# Merging logic:
# - Order of supplied TSVs matters
# - All columns are included (i.e. union of all columns present)
# - The last non-empty value read (from different TSVs) is used. I.e. values are overwritten.
# - Missing data is represented by an empty string
#
# We use one-hot encoding to specify which origin(s) a piece of metadata came from

if __name__ == '__main__':
def parse_args():
parser = argparse.ArgumentParser(
description="Remove extraneous colorings",
description="""
Custom script to combine metadata files from different origins.
In the case where metadata files specify different values, the latter provided file will take priority.
Columns will be added for each origin with values "yes" or "no" to identify the input source (origin) of each sample.
""",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('--metadata', required=True, nargs='+', metavar="TSV", help="Metadata files")
parser.add_argument('--origins', required=True, nargs='+', metavar="STR", help="Names of origins (order should match provided metadata)")
parser.add_argument('--output', required=True, metavar="TSV", help="Output (merged) metadata")
return parser.parse_args()

parser.add_argument('--input', type=str, metavar="JSON", required=True, help="input Auspice JSON")
parser.add_argument('--tree', type=str, required=True, help="tree file")
parser.add_argument('--clades', type=str, required=True, help="clades")
parser.add_argument('--mutations', type=str, required=True, help="mutations")
parser.add_argument('--output', type=str, metavar="JSON", required=True, help="output Auspice JSON")
args = parser.parse_args()

T = Phylo.read(args.tree, 'newick')
if __name__ == '__main__':
args = parse_args()
try:
assert(len(args.metadata)==len(args.origins))
assert(len(args.origins)>1)
except AssertionError:
print("Error. Please check your inputs - there must be the same number of metadata files as origins provided, and there must be more than one of each!")
sys.exit(2)

with open(args.mutations, "r") as f:
mutation_json = json.load(f)['nodes']
# READ IN METADATA FILES
metadata = []
for (origin, fname) in zip(args.origins, args.metadata):
data, columns = read_metadata(fname)
metadata.append({'origin': origin, "fname": fname, 'data': data, 'columns': columns, 'strains': {s for s in data.keys()}})

with open(args.clades, "r") as f:
clades_json = json.load(f)['nodes']
# SUMMARISE INPUT METADATA
print(f"Parsed {len(metadata)} metadata TSVs")
for m in metadata:
print(f"\t{m['origin']} ({m['fname']}): {len(m['data'].keys())} strains x {len(m['columns'])} columns")

with open(args.input, "r") as f:
input_json = json.load(f)
# BUILD UP COLUMN NAMES FROM MULTIPLE INPUTS TO PRESERVE ORDER
combined_columns = []
for m in metadata:
combined_columns.extend([c for c in m['columns'] if c not in combined_columns])
combined_columns.extend(list(args.origins))

nodes = {}
for n in T.find_clades(order='postorder'):
if n.is_terminal():
n.tip_count=1
else:
n.tip_count = sum([c.tip_count for c in n])
nodes[n.name] = {'tip_count':n.tip_count}
# ADD IN VALUES ONE BY ONE, OVERWRITING AS NECESSARY
combined_data = metadata[0]['data']
for strain in combined_data:
for column in combined_columns:
if column not in combined_data[strain]:
combined_data[strain][column] = EMPTY

labels = defaultdict(list)
for node in nodes:
for m in mutation_json[node]['muts']:
if m[0] in 'ACGT' and m[-1] in 'ACGT':
clade = clades_json[node]['clade_membership']
tmp_label = (clade, m)
labels[tmp_label].append((node, nodes[node]['tip_count']))
for idx in range(1, len(metadata)):
for strain, row in metadata[idx]['data'].items():
if strain not in combined_data:
combined_data[strain] = {c:EMPTY for c in combined_columns}
for column in combined_columns:
if column in row:
existing_value = combined_data[strain][column]
new_value = row[column]
# overwrite _ANY_ existing value if the overwriting value is non empty (and different)!
if new_value != EMPTY and new_value != existing_value:
if existing_value != EMPTY:
print(f"[{strain}::{column}] Overwriting {combined_data[strain][column]} with {new_value}")
combined_data[strain][column] = new_value

labeled_nodes = defaultdict(list)
for label in labels:
node = sorted(labels[label], key=lambda x:-x[1])[0]
labeled_nodes[node[0]].append('/'.join(label))
# one-hot encoding for origin
# note that we use "yes" / "no" here as Booleans are problematic for `augur filter`
for metadata_entry in metadata:
origin = metadata_entry['origin']
for strain in combined_data:
combined_data[strain][origin] = "yes" if strain in metadata_entry['strains'] else "no"

attach_labels(input_json["tree"], labeled_nodes)
print(f"Combined metadata: {len(combined_data.keys())} strains x {len(combined_columns)} columns")

with open(args.output, 'w') as f:
json.dump(input_json, f, indent=2)
with open(args.output, 'w') as fh:
tsv_writer = csv.writer(fh, delimiter='\t')
tsv_writer.writerow(combined_columns)
for row in combined_data.values():
tsv_writer.writerow([row[column] for column in combined_columns])
Loading