Skip to content

Commit efc5266

Browse files
committed
supported --fg_format 2 and --more_tree_plot. Use of float64 for csubst simulate. --fg_clade_permutation is broken.
1 parent a7e890a commit efc5266

File tree

9 files changed

+569
-425
lines changed

9 files changed

+569
-425
lines changed

csubst/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.3.19'
1+
__version__ = '1.4.0'

csubst/combination.py

Lines changed: 112 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ def node_union(index_combinations, target_nodes, df_mmap, mmap_start):
1919
df_mmap[i, :] = node_union
2020
i += 1
2121

22-
def nc_matrix2id_combinations(nc_matrix, arity, ncpu, verbose):
23-
start = time.time()
22+
def nc_matrix2id_combinations(nc_matrix, arity, ncpu):
2423
rows, cols = numpy.where(numpy.equal(nc_matrix, 1))
2524
unique_cols = numpy.unique(cols)
2625
ind2 = numpy.arange(arity, dtype=numpy.int64)
@@ -31,12 +30,10 @@ def nc_matrix2id_combinations(nc_matrix, arity, ncpu, verbose):
3130
(chunk, start, arity, ind2, rows, cols, unique_cols) for chunk, start in zip(chunks, starts)
3231
)
3332
id_combinations = numpy.concatenate(out)
34-
if verbose:
35-
print('Time elapsed for generating branch combinations: {:,} sec'.format(int(time.time() - start)))
3633
return id_combinations
3734

3835
def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbose=True):
39-
g['fg_dependent_id_combinations'] = None
36+
g['fg_dependent_id_combinations'] = dict()
4037
tree = g['tree']
4138
all_nodes = [ node for node in tree.traverse() if not node.is_root() ]
4239
if verbose:
@@ -48,8 +45,37 @@ def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbos
4845
target_nodes.append(node.numerical_label)
4946
target_nodes = numpy.array(target_nodes)
5047
node_combinations = list(itertools.combinations(target_nodes, arity))
51-
node_combinations = [set(nc) for nc in node_combinations]
52-
node_combinations = numpy.array([list(nc) for nc in node_combinations])
48+
node_combinations = [ set(nc) for nc in node_combinations ]
49+
node_combinations = numpy.array([ list(nc) for nc in node_combinations ])
50+
elif isinstance(target_nodes, dict):
51+
trait_names = list(target_nodes.keys())
52+
node_combination_dict = dict()
53+
for trait_name in trait_names:
54+
if (target_nodes[trait_name].shape.__len__()==1):
55+
target_nodes[trait_name] = numpy.expand_dims(target_nodes[trait_name], axis=1)
56+
index_combinations = list(itertools.combinations(numpy.arange(target_nodes[trait_name].shape[0]), 2))
57+
if len(index_combinations)==0:
58+
sys.stderr.write('There is no target branch combination at K = {:,}.\n'.format(arity))
59+
id_combinations = numpy.zeros(shape=[0,arity], dtype=numpy.int64)
60+
return g, id_combinations
61+
if verbose:
62+
txt = 'Number of branch combinations before independency check for {}: {:,}'
63+
print(txt.format(trait_name, len(index_combinations)), flush=True)
64+
axis = (len(index_combinations), arity)
65+
mmap_out = os.path.join(os.getcwd(), 'tmp.csubst.node_combinations.mmap')
66+
if os.path.exists(mmap_out): os.unlink(mmap_out)
67+
df_mmap = numpy.memmap(mmap_out, dtype=numpy.int32, shape=axis, mode='w+')
68+
chunks,starts = parallel.get_chunks(index_combinations, g['threads'])
69+
joblib.Parallel(n_jobs=g['threads'], max_nbytes=None, backend='multiprocessing')(
70+
joblib.delayed(node_union)
71+
(ids, target_nodes[trait_name], df_mmap, ms) for ids, ms in zip(chunks, starts)
72+
)
73+
is_valid_combination = (df_mmap.sum(axis=1)!=0)
74+
if (is_valid_combination.sum()>0):
75+
node_combination_dict[trait_name] = numpy.unique(df_mmap[is_valid_combination,:], axis=0)
76+
else:
77+
node_combination_dict[trait_name] = numpy.zeros(shape=[0,arity], dtype=numpy.int64)
78+
node_combinations = numpy.unique(numpy.concatenate(list(node_combination_dict.values()), axis=0), axis=0)
5379
elif isinstance(target_nodes, numpy.ndarray):
5480
if (target_nodes.shape.__len__()==1):
5581
target_nodes = numpy.expand_dims(target_nodes, axis=1)
@@ -74,10 +100,10 @@ def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbos
74100
node_combinations = numpy.unique(df_mmap[is_valid_combination,:], axis=0)
75101
else:
76102
node_combinations = numpy.zeros(shape=[0,arity], dtype=numpy.int64)
103+
else:
104+
raise Exception('target_nodes must be either None, dict, or numpy.ndarray.')
77105
if verbose:
78-
num_target_node = numpy.unique(target_nodes.flatten()).shape[0]
79-
print("Number of target branches: {:,}".format(num_target_node), flush=True)
80-
print("Number of independent/non-independent branch combinations: {:,}".format(node_combinations.shape[0]), flush=True)
106+
print("Number of all branch combinations before independency check: {:,}".format(node_combinations.shape[0]), flush=True)
81107
nc_matrix = numpy.zeros(shape=(len(all_nodes), node_combinations.shape[0]), dtype=bool)
82108
for i in numpy.arange(node_combinations.shape[0]):
83109
nc_matrix[node_combinations[i,:],i] = True
@@ -87,26 +113,39 @@ def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbos
87113
if verbose:
88114
print('Number of non-independent branch combinations to be removed: {:,}'.format(is_dependent_col.sum()), flush=True)
89115
nc_matrix = nc_matrix[:,~is_dependent_col]
90-
if (g['foreground'] is not None)&(len(g['fg_dep_ids']) > 0):
91-
is_fg_dependent_col = False
92-
for fg_dep_id in g['fg_dep_ids']:
116+
id_combinations = numpy.zeros(shape=(0,arity), dtype=numpy.int64)
117+
start = time.time()
118+
trait_names = g['fg_df'].columns[1:len(g['fg_df'].columns)].tolist()
119+
for trait_name in trait_names:
120+
if verbose:
121+
if isinstance(target_nodes, dict):
122+
num_target_node = numpy.unique(target_nodes[trait_name].flatten()).shape[0]
123+
else:
124+
num_target_node = numpy.unique(target_nodes.flatten()).shape[0]
125+
print("Number of target branches: {:,}".format(num_target_node), flush=True)
126+
is_fg_dependent_col = numpy.zeros(shape=(nc_matrix.shape[1],), dtype=bool)
127+
for fg_dep_id in g['fg_dep_ids'][trait_name]:
93128
is_fg_dependent_col |= (nc_matrix[fg_dep_id, :].sum(axis=0) > 1)
94129
if (g['exhaustive_until']>=arity):
95130
if verbose:
96-
txt = 'Detected {:,} (out of {:,}) foreground branch combinations to be treated as non-foreground '
97-
txt += '(e.g., parent-child pairs).'
98-
print(txt.format(is_fg_dependent_col.sum(), is_fg_dependent_col.shape[0]), flush=True)
131+
txt = 'Number of non-independent foreground branch combinations to be non-foreground-marked for {}: {:,} / {:,}'
132+
print(txt.format(trait_name, is_fg_dependent_col.sum(), is_fg_dependent_col.shape[0]), flush=True)
99133
fg_dep_nc_matrix = numpy.copy(nc_matrix)
100134
fg_dep_nc_matrix[:,~is_fg_dependent_col] = False
101-
g['fg_dependent_id_combinations'] = nc_matrix2id_combinations(fg_dep_nc_matrix, arity, g['threads'], verbose)
135+
g['fg_dependent_id_combinations'][trait_name] = nc_matrix2id_combinations(fg_dep_nc_matrix, arity, g['threads'])
136+
if trait_name == trait_names[0]:
137+
id_combinations = nc_matrix2id_combinations(nc_matrix, arity, g['threads'])
102138
else:
103139
if verbose:
104-
txt = 'removing {:,} (out of {:,}) dependent foreground branch combinations.'
105-
print(txt.format(is_fg_dependent_col.sum(), is_fg_dependent_col.shape[0]), flush=True)
140+
txt = 'Removing {:,} (out of {:,}) non-independent foreground branch combinations for {}.'
141+
print(txt.format(is_fg_dependent_col.sum(), is_fg_dependent_col.shape[0], trait_name), flush=True)
106142
nc_matrix = nc_matrix[:,~is_fg_dependent_col]
107-
id_combinations = nc_matrix2id_combinations(nc_matrix, arity, g['threads'], verbose)
143+
g['fg_dependent_id_combinations'][trait_name] = numpy.array([])
144+
trait_id_combinations = nc_matrix2id_combinations(nc_matrix, arity, g['threads'])
145+
id_combinations = numpy.unique(numpy.concatenate((id_combinations, trait_id_combinations), axis=0), axis=0)
108146
if verbose:
109-
print("Number of independent branch combinations: {:,}".format(id_combinations.shape[0]), flush=True)
147+
print('Time elapsed for generating branch combinations: {:,} sec'.format(int(time.time() - start)))
148+
print("Number of independent branch combinations to be analyzed: {:,}".format(id_combinations.shape[0]), flush=True)
110149
return g,id_combinations
111150

112151
def node_combination_subsamples_rifle(g, arity, rep):
@@ -198,58 +237,67 @@ def calc_substitution_patterns(cb):
198237
cb.loc[:,key+'_pattern_id'] = sub_patterns4.loc[:,'sub_pattern_id']
199238
return cb
200239

201-
def get_dep_ids(g):
202-
dep_ids = list()
240+
def get_global_dep_ids(g):
241+
global_dep_ids = list()
203242
for leaf in g['tree'].iter_leaves():
204-
ancestor_nns = [ node.numerical_label for node in leaf.iter_ancestors() if not node.is_root() ]
205-
dep_id = [leaf.numerical_label,] + ancestor_nns
243+
ancestor_nns = [node.numerical_label for node in leaf.iter_ancestors() if not node.is_root()]
244+
dep_id = [leaf.numerical_label, ] + ancestor_nns
206245
dep_id = numpy.sort(numpy.array(dep_id))
207-
dep_ids.append(dep_id)
208-
if g['exclude_sister_pair']:
209-
for node in g['tree'].traverse():
210-
children = node.get_children()
211-
if len(children)>1:
212-
dep_id = numpy.sort(numpy.array([ node.numerical_label for node in children ]))
213-
dep_ids.append(dep_id)
246+
global_dep_ids.append(dep_id)
247+
if g['exclude_sister_pair']:
248+
for node in g['tree'].traverse():
249+
children = node.get_children()
250+
if len(children)>1:
251+
dep_id = numpy.sort(numpy.array([ node.numerical_label for node in children ]))
252+
global_dep_ids.append(dep_id)
214253
root_nn = g['tree'].numerical_label
215-
root_state_sum = g['state_cdn'][root_nn,:,:].sum()
216-
if (root_state_sum==0):
254+
root_state_sum = g['state_cdn'][root_nn, :, :].sum()
255+
if (root_state_sum == 0):
217256
print('Ancestral states were not estimated on the root node. Excluding sub-root nodes from the analysis.')
218-
subroot_nns = [ node.numerical_label for node in g['tree'].get_children() ]
257+
subroot_nns = [node.numerical_label for node in g['tree'].get_children()]
219258
for subroot_nn in subroot_nns:
220259
for node in g['tree'].traverse():
221260
if node.is_root():
222261
continue
223-
if subroot_nn==node.numerical_label:
262+
if subroot_nn == node.numerical_label:
224263
continue
225-
ancestor_nns = [ anc.numerical_label for anc in node.iter_ancestors() ]
264+
ancestor_nns = [anc.numerical_label for anc in node.iter_ancestors()]
226265
if subroot_nn in ancestor_nns:
227266
continue
228-
dep_ids.append(numpy.array([subroot_nn, node.numerical_label]))
229-
g['dep_ids'] = dep_ids
230-
if (g['foreground'] is not None)&(g['fg_exclude_wg']):
231-
fg_dep_ids = list()
232-
for i in numpy.arange(len(g['fg_leaf_name'])):
233-
tmp_fg_dep_ids = list()
234-
for node in g['tree'].traverse():
235-
if node.is_root():
236-
continue
237-
is_all_leaf_lineage_fg = all([ ln in g['fg_leaf_name'][i] for ln in node.get_leaf_names() ])
238-
if not is_all_leaf_lineage_fg:
239-
continue
240-
is_up_all_leaf_lineage_fg = all([ ln in g['fg_leaf_name'][i] for ln in node.up.get_leaf_names() ])
241-
if is_up_all_leaf_lineage_fg:
242-
continue
243-
if node.is_leaf():
244-
tmp_fg_dep_ids.append(node.numerical_label)
245-
else:
246-
descendant_nn = [ n.numerical_label for n in node.get_descendants() ]
247-
tmp_fg_dep_ids += [node.numerical_label,] + descendant_nn
248-
if len(tmp_fg_dep_ids)>1:
249-
fg_dep_ids.append(numpy.sort(numpy.array(tmp_fg_dep_ids)))
250-
if (g['mg_sister'])|(g['mg_parent']):
251-
fg_dep_ids.append(numpy.sort(numpy.array(g['mg_id'])))
252-
g['fg_dep_ids'] = fg_dep_ids
253-
else:
254-
g['fg_dep_ids'] = numpy.array([])
267+
global_dep_ids.append(numpy.array([subroot_nn, node.numerical_label]))
268+
return global_dep_ids
269+
270+
def get_foreground_dep_ids(g):
271+
fg_dep_ids = dict()
272+
for trait_name in g['fg_df'].columns[1:len(g['fg_df'].columns)]:
273+
if (g['foreground'] is not None)&(g['fg_exclude_wg']):
274+
fg_dep_ids[trait_name] = list()
275+
for i in numpy.arange(len(g['fg_leaf_names'][trait_name])):
276+
fg_lineage_leaf_names = g['fg_leaf_names'][trait_name][i]
277+
tmp_fg_dep_ids = list()
278+
for node in g['tree'].traverse():
279+
if node.is_root():
280+
continue
281+
is_all_leaf_lineage_fg = all([ ln in fg_lineage_leaf_names for ln in node.get_leaf_names() ])
282+
if not is_all_leaf_lineage_fg:
283+
continue
284+
is_up_all_leaf_lineage_fg = all([ ln in fg_lineage_leaf_names for ln in node.up.get_leaf_names() ])
285+
if is_up_all_leaf_lineage_fg:
286+
continue
287+
if node.is_leaf():
288+
tmp_fg_dep_ids.append(node.numerical_label)
289+
else:
290+
descendant_nn = [ n.numerical_label for n in node.get_descendants() ]
291+
tmp_fg_dep_ids += [node.numerical_label,] + descendant_nn
292+
if len(tmp_fg_dep_ids)>1:
293+
fg_dep_ids[trait_name].append(numpy.sort(numpy.array(tmp_fg_dep_ids)))
294+
if (g['mg_sister'])|(g['mg_parent']):
295+
fg_dep_ids[trait_name].append(numpy.sort(numpy.array(g['mg_ids'][trait_name])))
296+
else:
297+
fg_dep_ids[trait_name] = numpy.array([])
298+
return fg_dep_ids
299+
300+
def get_dep_ids(g):
301+
g['dep_ids'] = get_global_dep_ids(g)
302+
g['fg_dep_ids'] = get_foreground_dep_ids(g)
255303
return g

csubst/csubst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ if __name__ == "__main__":
114114
'The file should contain two columns separated by a tab: '
115115
'1st column for lineage IDs and 2nd for regex-compatible leaf names. '
116116
'See https://github.com/kfuku52/csubst/wiki/Foreground-specification')
117+
psr_fg.add_argument('--fg_format', metavar='1|2', default=1, type=int, choices=[1,2],
118+
help='default=%(default)s: Table format of --foreground.')
117119
psr_fg.add_argument('--fg_exclude_wg', metavar='yes|no', default='yes', type=strtobool,
118120
help='default=%(default)s: Set "yes" to exclude branch combinations '
119121
'within individual foreground lineages.')
@@ -283,6 +285,8 @@ if __name__ == "__main__":
283285
analyze.add_argument('--cbs', metavar='yes|no', default='no', type=strtobool,
284286
help='default=%(default)s: Combinatorial-branch-site output. Set "yes" to generate the output tsv.')
285287
# Plot outputs
288+
analyze.add_argument('--more_tree_plot', metavar='yes|no', default='no', type=strtobool,
289+
help='default=%(default)s: More tree plots to generate.')
286290
analyze.add_argument('--plot_state_aa', metavar='yes|no', default='no', type=strtobool,
287291
help='default=%(default)s: Tree plots with per-site ancestral amino acid states. '
288292
'This option will generate many pdfs')

0 commit comments

Comments
 (0)