-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfuncs.py
299 lines (219 loc) · 8.05 KB
/
funcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import os
import numpy as np
# Get helpers from exp directory
import sys
sys.path.append("../exp/")
from helpers import get_ensemble_options
from models import get_base_parcel_names
from datetime import timedelta
def get_multi_parcel_size(parcel, parc_dr):
    '''Return the size of a multi-parcellation: the max of its component
    sizes for grid ensembles, otherwise the sum.'''

    # Resolve the component parcellations, then size each one
    component_names = get_base_parcel_names(parcel)
    component_sizes = [get_parcel_size(name, parc_dr)
                       for name in component_names]

    # Grid ensembles use the largest component; all other
    # multi-parcellations combine every component
    if parcel.startswith('grid_'):
        return max(component_sizes)
    return sum(component_sizes)
def get_parcel_size(parcel, parc_dr):
    '''Return the number of parcels in the named parcellation.

    Ensemble parcellations are sized from their components, random
    parcels encode their size in the name, freesurfer parcellations
    have fixed known sizes, and anything else is sized by loading the
    saved .npy parcel array from parc_dr.
    '''

    # Ensemble (multi-parcellation) cases
    if parcel.startswith(('stacked_', 'voted_', 'grid_')):
        return get_multi_parcel_size(parcel, parc_dr)

    # Random parcels: size is the second underscore-separated field
    if parcel.startswith('random_'):
        return int(parcel.split('_')[1])

    # Fixed-size freesurfer parcellations
    freesurfer_sizes = {'freesurfer_destr': 150,
                        'freesurfer_desikan': 68}
    if parcel in freesurfer_sizes:
        return freesurfer_sizes[parcel]

    # Otherwise load the saved parcel array and inspect it
    loaded = np.load(os.path.join(parc_dr, parcel + '.npy'))

    # Probabilistic parcels are 2D: one column per parcel
    if loaded.ndim == 2:
        return loaded.shape[1]

    # Static parcels: unique labels minus one for the empty / 0 label
    return len(np.unique(loaded)) - 1
def get_parc_sizes(parc_dr='../parcels',
                   base=False,
                   ico=False,
                   random=False,
                   fs=False,
                   stacked=False,
                   voted=False,
                   grid=False,
                   add_special=False,
                   everything=False,
                   size_min=None,
                   size_max=None):
    '''Collect the requested parcellations and return {name: size}.

    The boolean flags select which families of parcellations to
    include; everything turns all of them on. size_min / size_max
    optionally filter the result to sizes within the inclusive bounds.
    '''

    # everything is shorthand for enabling every family
    if everything:
        base = ico = random = fs = True
        stacked = voted = grid = True
        add_special = True

    selected = []

    # Saved base parcels: classify each file in the parcel directory
    for name in (f.replace('.npy', '') for f in os.listdir(parc_dr)):
        if name.startswith('random_'):
            if random:
                selected.append(name)
        elif name.startswith('icosahedron'):
            if ico:
                selected.append(name)
        elif base:
            selected.append(name)

    # The freesurfer parcellations are not stored as .npy files
    if fs:
        selected += ['freesurfer_destr', 'freesurfer_desikan']

    # Extra multiple-parcellation (ensemble) options
    for enabled, kind in ((stacked, 'stacked'),
                          (voted, 'voted'),
                          (grid, 'grid')):
        if enabled:
            selected += get_ensemble_options(kind, add_special=add_special)

    # Look up every parcel's size
    parc_sizes = {name: get_parcel_size(name, parc_dr)
                  for name in selected}

    # Apply any passed size restrictions (inclusive bounds kept)
    lo = float('-inf') if size_min is None else size_min
    hi = float('inf') if size_max is None else size_max
    return {name: sz for name, sz in parc_sizes.items()
            if lo <= sz <= hi}
def conv_t_delta(in_str):
    '''Parse a "Time Elapsed: H:M:S" log line into a timedelta.

    Returns None (after printing a warning) when the string does not
    split into exactly three ':'-separated fields.
    '''

    cleaned = in_str.replace('Time Elapsed: ', '').strip()
    pieces = cleaned.split(':')

    # Anything but H:M:S is unexpected — warn and fall through to None
    if len(pieces) != 3:
        print('Error with:', pieces)
        return None

    hours, minutes, seconds = (int(p) for p in pieces)
    return timedelta(hours=hours, minutes=minutes, seconds=seconds)
def is_binary(target):
    '''Return True if the named target variable is binary.'''

    name = target.rstrip()

    # Convention: binary targets carry a '_binary' suffix
    if name.endswith('_binary'):
        return True

    # Known binary targets that do not follow the naming convention
    exceptions = {'ksads_back_c_det_susp_p', 'married.bl',
                  'accult_phenx_q2_p', 'devhx_5_twin_p',
                  'sex_at_birth', 'devhx_6_pregnancy_planned_p',
                  'devhx_12a_born_premature_p',
                  'ksads_back_c_mh_sa_p'}
    return name in exceptions
def get_p_type(parcel):
    '''Classify a parcellation name as 'stacked', 'voted', 'grid' or 'base'.

    Fix: match 'voted_' and 'grid_' including the trailing underscore,
    consistent with the 'stacked_' check here and with the prefix tests
    used elsewhere in this file — previously a name merely beginning
    with 'voted'/'grid' (without the underscore separator) was
    misclassified as an ensemble type.
    '''

    # Ensemble prefixes map directly to their type label
    for prefix, p_type in (('stacked_', 'stacked'),
                           ('voted_', 'voted'),
                           ('grid_', 'grid')):
        if parcel.startswith(prefix):
            return p_type

    # Everything else is a plain base parcellation
    return 'base'
def get_time_elapsed(txt, n_strip):
    '''Total up the "Time Elapsed" lines from a run's log.

    Returns the combined elapsed time, or None when the expected number
    of time lines is missing (the run did not fully finish). Runs whose
    n_strip has 4 fields log one time standing in for 5 folds; all
    other runs log one time per fold (5 expected).
    '''

    prefix = 'Time Elapsed: '
    times = [conv_t_delta(line) for line in txt
             if line.startswith(prefix)]

    # 4-field runs: a single time covers all 5 folds
    if len(n_strip) == 4:
        return times[0] * 5 if len(times) == 1 else None

    # Other runs: one time per fold, all 5 must be present
    return np.sum(times) if len(times) == 5 else None
def get_n_jobs(txt):
    '''Return the n_jobs value parsed from the first matching log line.

    Raises ValueError if no line containing 'n_jobs = ' is present.
    '''

    marker = 'n_jobs = '
    for line in txt:
        if marker in line:
            return int(line.replace(marker, '').rstrip())
    raise ValueError('no n_jobs line found')
def extract_run_info(txt):
    '''Return the run description from the first "Running for:" log line.

    Raises ValueError if no such line is present.
    '''

    marker = 'Running for:'
    for line in txt:
        if marker in line:
            return line.replace(marker, '').strip()
    raise ValueError('no "Running for:" line found')
def extract(txt, parc_sizes, skip_svm=False):
    '''Extract summary info for one run from its log lines.

    Returns (model, size, is_binary, seconds, p_type, n_jobs,
    n_load_saved), or None when the run should be skipped: unfinished,
    an SVM run while skip_svm is set, an unknown parcel, or incomplete
    timing information.
    '''

    # Unfinished runs never print the validation header
    if 'Validation Scores\n' not in txt:
        return None

    # Run description fields: parcel---model---target[---...]
    n_strip = extract_run_info(txt).split('---')
    parcel = n_strip[0]
    model = n_strip[1]
    is_b = is_binary(n_strip[2])

    # Optionally drop SVM runs entirely
    if skip_svm and model == 'SVM':
        return None

    # Ensemble type (or 'base') from the parcel name
    p_type = get_p_type(parcel)

    # Unknown / invalid parcels are skipped
    if parcel not in parc_sizes:
        return None
    size = parc_sizes[parcel]

    # Missing or incomplete timing means the run didn't fully finish
    elapsed = get_time_elapsed(txt, n_strip)
    if elapsed is None:
        return None
    secs = elapsed.total_seconds()

    n_jobs = get_n_jobs(txt)

    # Count how many component loads were served from cache
    n_load_saved = sum('Loading from saved!' in line for line in txt)

    return model, size, is_b, secs, p_type, n_jobs, n_load_saved
def save_stats_summary(model, name):
    '''Render a fitted model's summary tables as HTML under docs.

    The first summary table is truncated just before its "Time:" row,
    both tables get centered styling (the second one wider), and the
    combined page is written to ../docs/_includes/<name>.html.
    '''

    summary = model.summary()

    # First table: keep only the portion before the "Time:" row
    table1 = summary.tables[0].as_html()
    cut = table1.index('<tr>\n <th>Time:</th>')
    table1 = table1[:cut] + '</table>'

    # Swap in a centered, width-constrained <table> header
    plain_header = '<table class="simpletable">'
    styled_header = '<table class="simpletable" style="margin-left: auto; margin-right: auto; width: 75%" align="center">'
    table1 = table1.replace(plain_header, styled_header)

    # Second table gets the same styling but more width
    table2 = summary.tables[1].as_html().replace(plain_header, styled_header)
    table2 = table2.replace('width: 75%', 'width: 95%')

    page = '<html><body>' + table1 + '<br>' + table2 + '</body></html>'
    with open('../docs/_includes/' + name + '.html', 'w') as f:
        f.write(page)
def clean_col_names(r_df):
    '''Return a copy of the results frame with display-friendly column names.'''

    display_names = {
        'Mean_Rank': 'Mean Rank',
        'Median_Rank': 'Median Rank',
        'r2': 'Mean R2',
        'roc_auc': 'Mean ROC AUC',
        'target': 'Target',
        'Model': 'Pipeline',
        'full_name': 'Parcellation',
        'Mean_Score': 'Mean Score',
    }
    return r_df.rename(display_names, axis=1)
def save_results_table(r_df, name):
    '''Clean, reorder, sort, and save a results table as sortable HTML.'''

    r_df = clean_col_names(r_df)

    # Fixed display column order; Median Rank is appended when present
    columns = ['Parcellation', 'Mean Rank', 'Size',
               'Mean R2', 'Mean ROC AUC']
    if 'Median Rank' in list(r_df):
        columns.append('Median Rank')

    # Restrict to the display columns, best (lowest) mean rank first
    ordered = r_df[columns].sort_values('Mean Rank')

    # Save as html
    save_table(ordered, name)
def save_table(r_df, name):
    '''Write r_df as a sortable HTML table to ../docs/_includes/<name>.html.

    The sorttable.js script is prepended so the rendered table is
    client-side sortable in the docs pages.
    '''

    script_tag = '<script src="https://www.kryogenix.org/code/browser/sorttable/sorttable.js"></script>'
    table_html = r_df.to_html(float_format="%.4f", classes=['sortable'],
                              border=0, index=False, justify='center')

    with open('../docs/_includes/' + name + '.html', 'w') as f:
        f.write(script_tag + table_html)