treetime/clock_filter_methods.py: 214 changes (105 additions, 109 deletions)
@@ -35,161 +35,157 @@ def residual_filter(tt, n_iqd):
def local_filter(tt, z_score_threshold):
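# Fit a per-node Gaussian timing model, flag tips whose sampling dates are
# inconsistent with their inferred dates (|z| above z_score_threshold), mark
# them as bad branches, store a summary table in tt.outliers, and return the
# number of flagged tips.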
tt.logger(f"TreeTime.ClockFilter: starting local_outlier_detection", 2)

node_info = collect_node_info(tt)
prepare_gauss_model(tt)
genome_wide_mu = tt.clock_model['slope']*tt.data.full_length
initial_params = {'mu':genome_wide_mu, 'sigma':3/genome_wide_mu, 'eps':0.5, 'Tc':None}
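# Initial model parameters: 'mu' is the genome-wide mutation rate (clock slope in
# subst/site/year times alignment length), 'sigma' is the assumed spread of tip dates
# around their inferred times (here 3/mu years), 'eps' regularises branches with very
# few mutations, and 'Tc' optionally enables a coalescent correction (off when None).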
gauss_params = iterate_gauss_model(tt, initial_params)
tt.logger(f"TreeTime.ClockFilter: z-scale {gauss_params['z_scale']:1.2f}", 2)

node_info, z_scale = calculate_node_timings(tt, node_info)
tt.logger(f"TreeTime.ClockFilter: z-scale {z_scale:1.2f}", 2)

outliers = flag_outliers(tt, node_info, z_score_threshold, z_scale)
flag_outliers(tt, z_score_threshold, gauss_params['z_scale'])

outliers = []
for n in tt.tree.get_terminals():
if n.name in outliers:
if n.gauss_model['outlier']:
n.bad_branch = True
outliers.append({'tau':n.gauss_model['tau'], 'avg_date':np.mean(n.gauss_model['observations']), 'obs':n.gauss_model['observations'],
'z':n.gauss_model['z'], 'diagnosis':n.gauss_model['diagnosis'], 'name':n.name})

tt.outliers=None
if len(outliers):
outlier_df = pd.DataFrame(outliers).T.loc[:,['avg_date', 'tau', 'z', 'diagnosis']]\
outlier_df = pd.DataFrame(outliers).loc[:,['name', 'avg_date', 'tau', 'z', 'diagnosis']]\
.rename(columns={'avg_date':'given_date', 'tau':'apparent_date', 'z':'z-score'})
outlier_df.set_index('name', inplace=True)
tt.logger("Clock_filter.local_filter marked the following outliers", 2, warn=True)
if tt.verbose>=2:
print(outlier_df)
tt.outliers = outlier_df
return len(outliers)


def flag_outliers(tt, node_info, z_score_threshold, z_scale):
def add_outlier_info(z, n, n_info, parent_tau, mu):
n_info['z'] = z
def flag_outliers(tt, z_score_threshold, z_scale):
def add_outlier_info(z, n, parent_tau, mu):
gm = n.gauss_model
gm['z'] = z
diagnosis=''
# muts = n_info["nmuts"] if n.is_terminal() else 0.0
# parent_tau = node_info[n.up.name]['tau'] if n.is_terminal() else n_info['tau']
if z<0:
if np.abs(n_info['avg_date']-parent_tau) > n_info["nmuts"]/mu:
if np.abs(np.mean(gm['observations'])-parent_tau) > gm["n_muts"]/mu:
diagnosis='date_too_early'
else:
diagnosis = 'excess_mutations'
else:
diagnosis = 'date_too_late'
n_info['diagnosis'] = diagnosis
return n_info
gm['diagnosis'] = diagnosis
if diagnosis:
gm['outlier'] = True
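# A tip receives a diagnosis (and is marked as an outlier) depending on the sign of
# its z-score: a negative z means the given date is earlier than the inferred time
# tau; if the gap to the parent's time exceeds what the observed mutations can
# account for (n_muts/mu years), the date itself is suspect ('date_too_early'),
# otherwise the branch carries 'excess_mutations'. A z >= 0 is reported as
# 'date_too_late'.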

outliers = {}
mu = tt.clock_model['slope']*tt.data.full_length
for n in tt.tree.get_terminals():
gm = n.gauss_model
if n.up.up is None:
continue # do not label children of the root as bad -- typically a problem with root choice that will be fixed anyway
n_info = node_info[n.name]
parent_tau = node_info[n.up.name]['tau']
if n_info['exact_date']:
z = (n_info['avg_date'] - n_info['tau'])/z_scale
parent_tau = n.up.gauss_model['tau']
if n.bad_branch:
add_outlier_info(0, n, parent_tau, mu)
elif gm['descendent_observations']>0:
z = (np.mean(gm['observations']) - gm['tau'])/z_scale
if np.abs(z) > z_score_threshold:
outliers[n.name] = add_outlier_info(z, n, n_info, parent_tau, mu)
add_outlier_info(z, n, parent_tau, mu)
elif n.raw_date_constraint and len(n.raw_date_constraint):
zs = [(n_info['tau'] - x)/z_scale for x in n.raw_date_constraint]
zs = [(gm['tau'] - x)/z_scale for x in n.raw_date_constraint]
if zs[0]*zs[1]>0 and np.min(np.abs(zs))>z_score_threshold:
outliers[n.name] = add_outlier_info(zs[0] if np.abs(zs[0])<np.abs(zs[1]) else zs[1],
n, n_info, parent_tau, mu)
add_outlier_info(zs[0] if np.abs(zs[0])<np.abs(zs[1]) else zs[1],
n, parent_tau, mu)
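# Tips with only an uncertain date interval contribute no observation to the fit;
# they are flagged only when the inferred time tau falls on the same side of both
# interval end points and the nearer end is more than z_score_threshold away.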

return outliers

def calculate_node_timings(tt, node_info, eps=0.2):
mu = tt.clock_model['slope']*tt.data.full_length
sigma_sq = (3/mu)**2
tt.logger(f"Clockfilter.calculate_node_timings: mu={mu:1.3e}/y, sigma={3/mu:1.3e}y", 2)
def iterate_gauss_model(tt, params):
from scipy.interpolate import interp1d
mu = params['mu']
sigma_sq = params['sigma']**2
eps = params['eps']
Tc = params['Tc']
with_coal = (Tc is not None) and ("lineages" in params)
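# The coalescent correction is applied only when a timescale Tc is given and a
# lineages-through-time interpolator from a previous pass is present in params;
# on the first pass the model is fitted without it.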

# backward pass
for n in tt.tree.find_clades(order='postorder'):
p = node_info[n.name]
if not p['exact_date'] or p['skip']:
gm = n.gauss_model
if gm['descendent_observations']==0:
continue
if with_coal:
n_lineages = params['lineages'](gm['tau'])
coal= n_lineages/Tc*(len(n.clades)-1) + 1/Tc
else:
coal = 0

if n.is_terminal():
prefactor = (p["observations"]/sigma_sq + mu**2/(p["nmuts"]+eps))
p["a"] = (p["avg_date"]/sigma_sq + mu*p["nmuts"]/(p["nmuts"]+eps))/prefactor
prefactor = (gm["n_observations"]/sigma_sq + mu**2/(gm["n_muts"]+eps))
gm["a"] = (np.sum(gm["observations"])/sigma_sq + mu*gm["n_muts"]/(gm["n_muts"]+eps) + coal)/prefactor
else:
children = [node_info[c.name] for c in n if (not node_info[c.name]['skip']) and node_info[c.name]['exact_date']]
valid_children = [c.gauss_model for c in n.clades if c.gauss_model['descendent_observations']>0]
tmp_children_1 = mu*np.sum([(mu*c["a"]-c["n_muts"])/(eps+c["n_muts"]) for c in valid_children])
tmp_children_2 = mu**2*np.sum([(1-c["b"])/(eps+c["n_muts"]) for c in valid_children])
if n==tt.tree.root:
tmp_children_1 = mu*np.sum([(mu*c["a"]-c["nmuts"])/(eps+c["nmuts"]) for c in children])
tmp_children_2 = mu**2*np.sum([(1-c["b"])/(eps+c["nmuts"]) for c in children])
prefactor = (p["observations"]/sigma_sq + tmp_children_2)
p["a"] = (p["observations"]*p["avg_date"]/sigma_sq + tmp_children_1)/prefactor
prefactor = (gm["n_observations"]/sigma_sq + tmp_children_2)
gm["a"] = (np.sum(gm['observations'])/sigma_sq + tmp_children_1 + coal)/prefactor
else:
tmp_children_1 = mu*np.sum([(mu*c["a"]-c["nmuts"])/(eps+c["nmuts"]) for c in children])
tmp_children_2 = mu**2*np.sum([(1-c["b"])/(eps+c["nmuts"]) for c in children])
prefactor = (p["observations"]/sigma_sq + mu**2/(p["nmuts"]+eps) + tmp_children_2)
p["a"] = (p["observations"]*p["avg_date"]/sigma_sq + mu*p["nmuts"]/(p["nmuts"]+eps)+tmp_children_1)/prefactor
p["b"] = mu**2/(p["nmuts"]+eps)/prefactor

node_info[tt.tree.root.name]["tau"] = node_info[tt.tree.root.name]["a"]

## need to deal with tips without exact dates below.
prefactor = (gm["n_observations"]/sigma_sq + mu**2/(gm["n_muts"]+eps) + tmp_children_2)
gm["a"] = (np.sum(gm['observations'])/sigma_sq + mu*gm["n_muts"]/(gm["n_muts"]+eps)+tmp_children_1 + coal)/prefactor
if tt.tree.root!=n:
gm["b"] = mu**2/(gm["n_muts"]+eps)/prefactor

# evaluation of tau at the root
gm_r = tt.tree.root.gauss_model
gm_r["tau"] = gm_r["a"]
n_lineage_list = [(gm_r["tau"], len(tt.tree.root.clades))]
# forward pass
tsq = 0
dt = 0
dev = []
logL = np.log(2*np.pi*sigma_sq)*0.5*gm_r["descendent_observations"]
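# Forward-pass accumulators: tsq and dt collect the statistics used to re-estimate
# the clock rate (mu = dt/tsq), dev gathers the residuals between observed and
# inferred tip dates (their standard deviation becomes the new z_scale), and logL
# tracks the log-likelihood of the current parameters.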
for n in tt.tree.get_nonterminals(order='preorder'):
p = node_info[n.name]
gm = n.gauss_model
for c in n:
c_info = node_info[c.name]
if c_info['skip']:
c_info['tau']=p['tau']
gm_c = c.gauss_model
if gm_c['descendent_observations']==0:
gm_c['tau']=gm['tau'] + gm_c['n_muts']/mu
else:
if c_info['exact_date']:
c_info["tau"] = c_info["a"] + c_info["b"]*p["tau"]
else:
c_info["tau"] = p["tau"] + c_info['nmuts']/mu
if c.is_terminal() and c_info['exact_date']:
dev.append(c_info['avg_date']-c_info['tau'])

return node_info, np.std(dev)


def collect_node_info(tt, percentile_for_exact_date=90):
node_info = {}
gm_c["tau"] = gm_c["a"] + gm_c["b"]*gm["tau"]
n_lineage_list.append((gm_c["tau"], len(c.clades)-1))
tsq += (gm_c["tau"] - gm["tau"])**2/(2*(gm_c["n_muts"]+eps))
dt += (gm_c["tau"] - gm["tau"])*gm_c["n_muts"]/(2*(gm_c["n_muts"]+eps))
if gm_c['n_observations']:
logL += 0.5*np.sum((gm_c['observations'] - gm_c['tau'])**2)/sigma_sq
dev.extend(gm_c['observations'] - gm_c['tau'])
logL += 0.5*((gm_c["tau"] - gm["tau"])*mu - gm_c["n_muts"])**2/(gm_c["n_muts"]+eps)
if with_coal:
n_lineages = params['lineages'](gm['tau'])
logL -= np.log(n_lineages/Tc)*(len(n.clades)-1)

n_lineage_list.sort()
n_lineage_list[-1] = (n_lineage_list[-1][0], 0)
return {'lineages': interp1d([x[0] for x in n_lineage_list], np.maximum(1, np.cumsum([x[1] for x in n_lineage_list])), kind='previous', bounds_error=False, fill_value=1),
'mu':dt/tsq, 'dmu':1/tsq**0.5, 'logL':logL, "sigma": params['sigma'], 'Tc':Tc, 'eps':eps, 'z_scale':np.std(dev)}
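# The returned dict carries an updated 'mu' together with the 'sigma', 'eps', 'Tc'
# and 'lineages' entries expected in params, so the fit can be refined by feeding
# the result back in, e.g. (illustrative sketch, n_iter is a placeholder):
#   p = initial_params
#   for _ in range(n_iter):
#       p = iterate_gauss_model(tt, p)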


def prepare_gauss_model(tt, uncertainty_cutoff=0.2):
aln = tt.aln or False
if aln and (not tt.sequence_reconstruction):
tt.infer_ancestral_sequences(infer_gtr=False)
L = tt.data.full_length
for n in tt.tree.find_clades(order='postorder'):
n.gauss_model = {'a':0, 'b':0, 'tau':0, 'z':0, 'outlier':False}
n.gauss_model['n_muts'] = len([m for m in n.mutations if m[-1] in 'ACGT']) if aln \
else np.round(n.branch_length*L)
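# The number of mutations per branch is taken from the reconstructed ancestral
# sequences when an alignment is available; otherwise it is approximated by the
# branch length times the alignment length.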
obs = []
if n.is_terminal() and (n.raw_date_constraint is not None) and (n.bad_branch is False):
if isinstance(n.raw_date_constraint, float):
obs.append(n.raw_date_constraint)
else: # interval
if n.raw_date_constraint[1]-n.raw_date_constraint[0]<=uncertainty_cutoff:
obs.append(np.mean(n.raw_date_constraint))

n.gauss_model['observations'] = np.array(obs)
n.gauss_model['n_observations'] = len(obs)
# total number of explicit date observations of downstream genotypes that aren't filtered
n.gauss_model['descendent_observations'] = n.gauss_model['n_observations'] + np.sum([c.gauss_model['descendent_observations'] for c in n.clades])
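# descendent_observations counts the usable tip dates in each node's subtree; nodes
# (and tips) without any such observation are skipped in the backward pass and later
# dated purely from their mutation count relative to the parent.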

date_uncertainty = [np.abs(n.raw_date_constraint[1]-n.raw_date_constraint[0])
if type(n.raw_date_constraint)!=float else 0.0
for n in tt.tree.get_terminals()
if n.raw_date_constraint is not None]
from scipy.stats import scoreatpercentile
uncertainty_cutoff = scoreatpercentile(date_uncertainty, percentile_for_exact_date)*1.01

for n in tt.tree.get_nonterminals(order='postorder'):
parent = {"dates": [], "tips": {}, "skip":False}
exact_dates = 0
for c in n:
if c.is_terminal():
child = {'skip':False}
child["nmuts"] = len([m for m in c.mutations if m[-1] in 'ACGT']) if aln \
else np.round(c.branch_length*L)
if c.raw_date_constraint is None:
child['exact_date'] = False
elif type(c.raw_date_constraint)==float:
child['exact_date'] = True
else:
child['exact_date'] = np.abs(c.raw_date_constraint[1]-c.raw_date_constraint[0])<=uncertainty_cutoff

if child['exact_date']:
exact_dates += 1
if child["nmuts"]==0:
child['skip'] = True
parent["tips"][c.name]={'date': np.mean(c.raw_date_constraint),
'exact_date':child['exact_date']}
else:
child['skip'] = False
child['observations'] = 1
if c.raw_date_constraint is not None:
child["avg_date"] = np.mean(c.raw_date_constraint)
node_info[c.name] = child
else:
if node_info[c.name]['exact_date']:
exact_dates += 1

parent['exact_date'] = exact_dates>0

parent["nmuts"] = len([m for m in n.mutations if m[-1] in 'ACGT']) if aln else np.round(n.branch_length*L)
d = [v['date'] for v in parent['tips'].values() if v['exact_date']]
parent["observations"] = len(d)
parent["avg_date"] = np.mean(d) if len(d) else 0.0
node_info[n.name] = parent

return node_info