@@ -19,8 +19,7 @@ def node_union(index_combinations, target_nodes, df_mmap, mmap_start):
1919 df_mmap [i , :] = node_union
2020 i += 1
2121
22- def nc_matrix2id_combinations (nc_matrix , arity , ncpu , verbose ):
23- start = time .time ()
22+ def nc_matrix2id_combinations (nc_matrix , arity , ncpu ):
2423 rows , cols = numpy .where (numpy .equal (nc_matrix , 1 ))
2524 unique_cols = numpy .unique (cols )
2625 ind2 = numpy .arange (arity , dtype = numpy .int64 )
@@ -31,12 +30,10 @@ def nc_matrix2id_combinations(nc_matrix, arity, ncpu, verbose):
3130 (chunk , start , arity , ind2 , rows , cols , unique_cols ) for chunk , start in zip (chunks , starts )
3231 )
3332 id_combinations = numpy .concatenate (out )
34- if verbose :
35- print ('Time elapsed for generating branch combinations: {:,} sec' .format (int (time .time () - start )))
3633 return id_combinations
3734
3835def get_node_combinations (g , target_nodes = None , arity = 2 , check_attr = None , verbose = True ):
39- g ['fg_dependent_id_combinations' ] = None
36+ g ['fg_dependent_id_combinations' ] = dict ()
4037 tree = g ['tree' ]
4138 all_nodes = [ node for node in tree .traverse () if not node .is_root () ]
4239 if verbose :
@@ -48,8 +45,37 @@ def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbos
4845 target_nodes .append (node .numerical_label )
4946 target_nodes = numpy .array (target_nodes )
5047 node_combinations = list (itertools .combinations (target_nodes , arity ))
51- node_combinations = [set (nc ) for nc in node_combinations ]
52- node_combinations = numpy .array ([list (nc ) for nc in node_combinations ])
48+ node_combinations = [ set (nc ) for nc in node_combinations ]
49+ node_combinations = numpy .array ([ list (nc ) for nc in node_combinations ])
50+ elif isinstance (target_nodes , dict ):
51+ trait_names = list (target_nodes .keys ())
52+ node_combination_dict = dict ()
53+ for trait_name in trait_names :
54+ if (target_nodes [trait_name ].shape .__len__ ()== 1 ):
55+ target_nodes [trait_name ] = numpy .expand_dims (target_nodes [trait_name ], axis = 1 )
56+ index_combinations = list (itertools .combinations (numpy .arange (target_nodes [trait_name ].shape [0 ]), 2 ))
57+ if len (index_combinations )== 0 :
58+ sys .stderr .write ('There is no target branch combination at K = {:,}.\n ' .format (arity ))
59+ id_combinations = numpy .zeros (shape = [0 ,arity ], dtype = numpy .int64 )
60+ return g , id_combinations
61+ if verbose :
62+ txt = 'Number of branch combinations before independency check for {}: {:,}'
63+ print (txt .format (trait_name , len (index_combinations )), flush = True )
64+ axis = (len (index_combinations ), arity )
65+ mmap_out = os .path .join (os .getcwd (), 'tmp.csubst.node_combinations.mmap' )
66+ if os .path .exists (mmap_out ): os .unlink (mmap_out )
67+ df_mmap = numpy .memmap (mmap_out , dtype = numpy .int32 , shape = axis , mode = 'w+' )
68+ chunks ,starts = parallel .get_chunks (index_combinations , g ['threads' ])
69+ joblib .Parallel (n_jobs = g ['threads' ], max_nbytes = None , backend = 'multiprocessing' )(
70+ joblib .delayed (node_union )
71+ (ids , target_nodes [trait_name ], df_mmap , ms ) for ids , ms in zip (chunks , starts )
72+ )
73+ is_valid_combination = (df_mmap .sum (axis = 1 )!= 0 )
74+ if (is_valid_combination .sum ()> 0 ):
75+ node_combination_dict [trait_name ] = numpy .unique (df_mmap [is_valid_combination ,:], axis = 0 )
76+ else :
77+ node_combination_dict [trait_name ] = numpy .zeros (shape = [0 ,arity ], dtype = numpy .int64 )
78+ node_combinations = numpy .unique (numpy .concatenate (list (node_combination_dict .values ()), axis = 0 ), axis = 0 )
5379 elif isinstance (target_nodes , numpy .ndarray ):
5480 if (target_nodes .shape .__len__ ()== 1 ):
5581 target_nodes = numpy .expand_dims (target_nodes , axis = 1 )
@@ -74,10 +100,10 @@ def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbos
74100 node_combinations = numpy .unique (df_mmap [is_valid_combination ,:], axis = 0 )
75101 else :
76102 node_combinations = numpy .zeros (shape = [0 ,arity ], dtype = numpy .int64 )
103+ else :
104+ raise Exception ('target_nodes must be either None, dict, or numpy.ndarray.' )
77105 if verbose :
78- num_target_node = numpy .unique (target_nodes .flatten ()).shape [0 ]
79- print ("Number of target branches: {:,}" .format (num_target_node ), flush = True )
80- print ("Number of independent/non-independent branch combinations: {:,}" .format (node_combinations .shape [0 ]), flush = True )
106+ print ("Number of all branch combinations before independency check: {:,}" .format (node_combinations .shape [0 ]), flush = True )
81107 nc_matrix = numpy .zeros (shape = (len (all_nodes ), node_combinations .shape [0 ]), dtype = bool )
82108 for i in numpy .arange (node_combinations .shape [0 ]):
83109 nc_matrix [node_combinations [i ,:],i ] = True
@@ -87,26 +113,39 @@ def get_node_combinations(g, target_nodes=None, arity=2, check_attr=None, verbos
87113 if verbose :
88114 print ('Number of non-independent branch combinations to be removed: {:,}' .format (is_dependent_col .sum ()), flush = True )
89115 nc_matrix = nc_matrix [:,~ is_dependent_col ]
90- if (g ['foreground' ] is not None )& (len (g ['fg_dep_ids' ]) > 0 ):
91- is_fg_dependent_col = False
92- for fg_dep_id in g ['fg_dep_ids' ]:
116+ id_combinations = numpy .zeros (shape = (0 ,arity ), dtype = numpy .int64 )
117+ start = time .time ()
118+ trait_names = g ['fg_df' ].columns [1 :len (g ['fg_df' ].columns )].tolist ()
119+ for trait_name in trait_names :
120+ if verbose :
121+ if isinstance (target_nodes , dict ):
122+ num_target_node = numpy .unique (target_nodes [trait_name ].flatten ()).shape [0 ]
123+ else :
124+ num_target_node = numpy .unique (target_nodes .flatten ()).shape [0 ]
125+ print ("Number of target branches: {:,}" .format (num_target_node ), flush = True )
126+ is_fg_dependent_col = numpy .zeros (shape = (nc_matrix .shape [1 ],), dtype = bool )
127+ for fg_dep_id in g ['fg_dep_ids' ][trait_name ]:
93128 is_fg_dependent_col |= (nc_matrix [fg_dep_id , :].sum (axis = 0 ) > 1 )
94129 if (g ['exhaustive_until' ]>= arity ):
95130 if verbose :
96- txt = 'Detected {:,} (out of {:,}) foreground branch combinations to be treated as non-foreground '
97- txt += '(e.g., parent-child pairs).'
98- print (txt .format (is_fg_dependent_col .sum (), is_fg_dependent_col .shape [0 ]), flush = True )
131+ txt = 'Number of non-independent foreground branch combinations to be non-foreground-marked for {}: {:,} / {:,}'
132+ print (txt .format (trait_name , is_fg_dependent_col .sum (), is_fg_dependent_col .shape [0 ]), flush = True )
99133 fg_dep_nc_matrix = numpy .copy (nc_matrix )
100134 fg_dep_nc_matrix [:,~ is_fg_dependent_col ] = False
101- g ['fg_dependent_id_combinations' ] = nc_matrix2id_combinations (fg_dep_nc_matrix , arity , g ['threads' ], verbose )
135+ g ['fg_dependent_id_combinations' ][trait_name ] = nc_matrix2id_combinations (fg_dep_nc_matrix , arity , g ['threads' ])
136+ if trait_name == trait_names [0 ]:
137+ id_combinations = nc_matrix2id_combinations (nc_matrix , arity , g ['threads' ])
102138 else :
103139 if verbose :
104- txt = 'removing {:,} (out of {:,}) dependent foreground branch combinations.'
105- print (txt .format (is_fg_dependent_col .sum (), is_fg_dependent_col .shape [0 ]), flush = True )
140+ txt = 'Removing {:,} (out of {:,}) non-independent foreground branch combinations for {} .'
141+ print (txt .format (is_fg_dependent_col .sum (), is_fg_dependent_col .shape [0 ], trait_name ), flush = True )
106142 nc_matrix = nc_matrix [:,~ is_fg_dependent_col ]
107- id_combinations = nc_matrix2id_combinations (nc_matrix , arity , g ['threads' ], verbose )
143+ g ['fg_dependent_id_combinations' ][trait_name ] = numpy .array ([])
144+ trait_id_combinations = nc_matrix2id_combinations (nc_matrix , arity , g ['threads' ])
145+ id_combinations = numpy .unique (numpy .concatenate ((id_combinations , trait_id_combinations ), axis = 0 ), axis = 0 )
108146 if verbose :
109- print ("Number of independent branch combinations: {:,}" .format (id_combinations .shape [0 ]), flush = True )
147+ print ('Time elapsed for generating branch combinations: {:,} sec' .format (int (time .time () - start )))
148+ print ("Number of independent branch combinations to be analyzed: {:,}" .format (id_combinations .shape [0 ]), flush = True )
110149 return g ,id_combinations
111150
112151def node_combination_subsamples_rifle (g , arity , rep ):
@@ -198,58 +237,67 @@ def calc_substitution_patterns(cb):
198237 cb .loc [:,key + '_pattern_id' ] = sub_patterns4 .loc [:,'sub_pattern_id' ]
199238 return cb
200239
201- def get_dep_ids (g ):
202- dep_ids = list ()
240+ def get_global_dep_ids (g ):
241+ global_dep_ids = list ()
203242 for leaf in g ['tree' ].iter_leaves ():
204- ancestor_nns = [ node .numerical_label for node in leaf .iter_ancestors () if not node .is_root () ]
205- dep_id = [leaf .numerical_label ,] + ancestor_nns
243+ ancestor_nns = [node .numerical_label for node in leaf .iter_ancestors () if not node .is_root ()]
244+ dep_id = [leaf .numerical_label , ] + ancestor_nns
206245 dep_id = numpy .sort (numpy .array (dep_id ))
207- dep_ids .append (dep_id )
208- if g ['exclude_sister_pair' ]:
209- for node in g ['tree' ].traverse ():
210- children = node .get_children ()
211- if len (children )> 1 :
212- dep_id = numpy .sort (numpy .array ([ node .numerical_label for node in children ]))
213- dep_ids .append (dep_id )
246+ global_dep_ids .append (dep_id )
247+ if g ['exclude_sister_pair' ]:
248+ for node in g ['tree' ].traverse ():
249+ children = node .get_children ()
250+ if len (children )> 1 :
251+ dep_id = numpy .sort (numpy .array ([ node .numerical_label for node in children ]))
252+ global_dep_ids .append (dep_id )
214253 root_nn = g ['tree' ].numerical_label
215- root_state_sum = g ['state_cdn' ][root_nn ,:, :].sum ()
216- if (root_state_sum == 0 ):
254+ root_state_sum = g ['state_cdn' ][root_nn , :, :].sum ()
255+ if (root_state_sum == 0 ):
217256 print ('Ancestral states were not estimated on the root node. Excluding sub-root nodes from the analysis.' )
218- subroot_nns = [ node .numerical_label for node in g ['tree' ].get_children () ]
257+ subroot_nns = [node .numerical_label for node in g ['tree' ].get_children ()]
219258 for subroot_nn in subroot_nns :
220259 for node in g ['tree' ].traverse ():
221260 if node .is_root ():
222261 continue
223- if subroot_nn == node .numerical_label :
262+ if subroot_nn == node .numerical_label :
224263 continue
225- ancestor_nns = [ anc .numerical_label for anc in node .iter_ancestors () ]
264+ ancestor_nns = [anc .numerical_label for anc in node .iter_ancestors ()]
226265 if subroot_nn in ancestor_nns :
227266 continue
228- dep_ids .append (numpy .array ([subroot_nn , node .numerical_label ]))
229- g ['dep_ids' ] = dep_ids
230- if (g ['foreground' ] is not None )& (g ['fg_exclude_wg' ]):
231- fg_dep_ids = list ()
232- for i in numpy .arange (len (g ['fg_leaf_name' ])):
233- tmp_fg_dep_ids = list ()
234- for node in g ['tree' ].traverse ():
235- if node .is_root ():
236- continue
237- is_all_leaf_lineage_fg = all ([ ln in g ['fg_leaf_name' ][i ] for ln in node .get_leaf_names () ])
238- if not is_all_leaf_lineage_fg :
239- continue
240- is_up_all_leaf_lineage_fg = all ([ ln in g ['fg_leaf_name' ][i ] for ln in node .up .get_leaf_names () ])
241- if is_up_all_leaf_lineage_fg :
242- continue
243- if node .is_leaf ():
244- tmp_fg_dep_ids .append (node .numerical_label )
245- else :
246- descendant_nn = [ n .numerical_label for n in node .get_descendants () ]
247- tmp_fg_dep_ids += [node .numerical_label ,] + descendant_nn
248- if len (tmp_fg_dep_ids )> 1 :
249- fg_dep_ids .append (numpy .sort (numpy .array (tmp_fg_dep_ids )))
250- if (g ['mg_sister' ])| (g ['mg_parent' ]):
251- fg_dep_ids .append (numpy .sort (numpy .array (g ['mg_id' ])))
252- g ['fg_dep_ids' ] = fg_dep_ids
253- else :
254- g ['fg_dep_ids' ] = numpy .array ([])
267+ global_dep_ids .append (numpy .array ([subroot_nn , node .numerical_label ]))
268+ return global_dep_ids
269+
270+ def get_foreground_dep_ids (g ):
271+ fg_dep_ids = dict ()
272+ for trait_name in g ['fg_df' ].columns [1 :len (g ['fg_df' ].columns )]:
273+ if (g ['foreground' ] is not None )& (g ['fg_exclude_wg' ]):
274+ fg_dep_ids [trait_name ] = list ()
275+ for i in numpy .arange (len (g ['fg_leaf_names' ][trait_name ])):
276+ fg_lineage_leaf_names = g ['fg_leaf_names' ][trait_name ][i ]
277+ tmp_fg_dep_ids = list ()
278+ for node in g ['tree' ].traverse ():
279+ if node .is_root ():
280+ continue
281+ is_all_leaf_lineage_fg = all ([ ln in fg_lineage_leaf_names for ln in node .get_leaf_names () ])
282+ if not is_all_leaf_lineage_fg :
283+ continue
284+ is_up_all_leaf_lineage_fg = all ([ ln in fg_lineage_leaf_names for ln in node .up .get_leaf_names () ])
285+ if is_up_all_leaf_lineage_fg :
286+ continue
287+ if node .is_leaf ():
288+ tmp_fg_dep_ids .append (node .numerical_label )
289+ else :
290+ descendant_nn = [ n .numerical_label for n in node .get_descendants () ]
291+ tmp_fg_dep_ids += [node .numerical_label ,] + descendant_nn
292+ if len (tmp_fg_dep_ids )> 1 :
293+ fg_dep_ids [trait_name ].append (numpy .sort (numpy .array (tmp_fg_dep_ids )))
294+ if (g ['mg_sister' ])| (g ['mg_parent' ]):
295+ fg_dep_ids [trait_name ].append (numpy .sort (numpy .array (g ['mg_ids' ][trait_name ])))
296+ else :
297+ fg_dep_ids [trait_name ] = numpy .array ([])
298+ return fg_dep_ids
299+
300+ def get_dep_ids (g ):
301+ g ['dep_ids' ] = get_global_dep_ids (g )
302+ g ['fg_dep_ids' ] = get_foreground_dep_ids (g )
255303 return g
0 commit comments