1
+ import matplotlib
2
+ matplotlib .use ('Agg' )
1
3
import matplotlib .pyplot as plt
4
+
2
5
import sys , os , time , random , pdb
3
6
import numpy as np
4
7
import pandas as pd
5
8
import torch .nn .functional as F
6
9
import torch
7
10
import pickle
8
- import tqdm
11
+ import tqdm , pdb
9
12
from sklearn .metrics import roc_auc_score
10
13
11
14
import config
12
15
13
- def get_roc_auc_score (y_true , y_probs , average = 'micro' ):
16
+ def get_roc_auc_score (y_true , y_probs ):
14
17
'''
15
18
Uses roc_auc_score function from sklearn.metrics to calculate the micro ROC AUC score for a given y_true and y_probs.
16
19
'''
17
- return roc_auc_score (y_true , y_probs , average = average )
20
+
21
+ with open (os .path .join (config .pkl_dir_path , config .disease_classes_pkl_path ), 'rb' ) as handle :
22
+ all_classes = pickle .load (handle )
23
+
24
+ NoFindingIndex = all_classes .index ('No Finding' )
25
+
26
+ if True :
27
+ print ('\n NoFindingIndex: ' , NoFindingIndex )
28
+ print ('y_true.shape, y_probs.shape ' , y_true .shape , y_probs .shape )
29
+ GT_and_probs = {'y_true' : y_true , 'y_probs' : y_probs }
30
+ with open ('GT_and_probs' , 'wb' ) as handle :
31
+ pickle .dump (GT_and_probs , handle , protocol = pickle .HIGHEST_PROTOCOL )
32
+
33
+ class_roc_auc_list = []
34
+ useful_classes_roc_auc_list = []
35
+
36
+ for i in range (y_true .shape [1 ]):
37
+ class_roc_auc = roc_auc_score (y_true [:, i ], y_probs [:, i ])
38
+ class_roc_auc_list .append (class_roc_auc )
39
+ if i != NoFindingIndex :
40
+ useful_classes_roc_auc_list .append (class_roc_auc )
41
+ if True :
42
+ print ('\n class_roc_auc_list: ' , class_roc_auc_list )
43
+ print ('\n useful_classes_roc_auc_list' , useful_classes_roc_auc_list )
44
+
45
+ return np .mean (np .array (useful_classes_roc_auc_list ))
18
46
19
47
def make_plot (epoch_train_loss , epoch_val_loss , total_train_loss_list , total_val_loss_list , save_name ):
20
48
'''
@@ -78,25 +106,25 @@ def get_resampled_train_val_dataloaders(XRayTrain_dataset, transform, bs):
78
106
79
107
return train_loader , val_loader
80
108
81
- def train_epoch (train_loader , model , loss_fn , optimizer , step_lr_scheduler , epochs_till_now , final_epoch , log_interval ):
109
+ def train_epoch (device , train_loader , model , loss_fn , optimizer , epochs_till_now , final_epoch , log_interval ):
82
110
'''
83
111
Takes in the data from the 'train_loader', calculates the loss over it using the 'loss_fn'
84
112
and optimizes the 'model' using the 'optimizer'
85
113
86
114
Also prints the loss and the ROC AUC score for the batches, after every 'log_interval' batches.
87
115
'''
88
- step_lr_scheduler .step () # the lr of the optimizer is multiplied with 'gamma' on the 'step_size'th time step() is called on step_lr_scheduler
89
- # if initial lr of the optimized is 0.001 and step_lr_scheduler has step_size = 2 and gamma = 0.8, on the 2nd call of step_lr_scheduler.step(), optimizer's lr becomes 0.001*gamma
90
116
model .train ()
91
117
92
118
running_train_loss = 0
93
119
train_loss_list = []
94
120
121
+ start_time = time .time ()
95
122
for batch_idx , (img , target ) in enumerate (train_loader ):
96
123
# print(type(img), img.shape) # , np.unique(img))
97
-
98
- start_time = time .time ()
99
124
125
+ img = img .to (device )
126
+ target = target .to (device )
127
+
100
128
optimizer .zero_grad ()
101
129
out = model (img )
102
130
loss = loss_fn (out , target )
@@ -108,19 +136,21 @@ def train_epoch(train_loader, model, loss_fn, optimizer, step_lr_scheduler, epoc
108
136
109
137
if (batch_idx + 1 )% log_interval == 0 :
110
138
# batch metric evaluation
111
- out_detached = out .detach ()
112
- batch_roc_auc_score = get_roc_auc_score (target , out_detached .numpy ())
139
+ # # out_detached = out.detach()
140
+ # # batch_roc_auc_score = get_roc_auc_score(target, out_detached.numpy())
113
141
# 'out' is a torch.Tensor and 'roc_auc_score' function first tries to convert it into a numpy array, but since 'out' has requires_grad = True, it throws an error
114
142
# RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
115
143
# so we have to 'detach' the 'out' tensor and then convert it into a numpy array to avoid the error !
116
144
117
145
batch_time = time .time () - start_time
118
146
m , s = divmod (batch_time , 60 )
119
- print ('Train Loss for batch {}/{} @epoch{}/{}: {} and batch_roc_auc_score: {} in {} mins {} secs' .format (str (batch_idx + 1 ).zfill (3 ), str (len (train_loader )).zfill (3 ), epochs_till_now , final_epoch , round (loss .item (), 5 ), round (batch_roc_auc_score , 5 ), int (m ), int (s )))
147
+ print ('Train Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs' .format (str (batch_idx + 1 ).zfill (3 ), str (len (train_loader )).zfill (3 ), epochs_till_now , final_epoch , round (loss .item (), 5 ), int (m ), round (s , 2 )))
148
+
149
+ start_time = time .time ()
120
150
121
151
return train_loss_list , running_train_loss / float (len (train_loader .dataset ))
122
152
123
- def val_epoch (val_loader , model , loss_fn , epochs_till_now = None , final_epoch = None , log_interval = 1 , test_only = False ):
153
+ def val_epoch (device , val_loader , model , loss_fn , epochs_till_now = None , final_epoch = None , log_interval = 1 , test_only = False ):
124
154
'''
125
155
It essentially takes in the val_loader/test_loader, the model and the loss function and evaluates
126
156
the loss and the ROC AUC score for all the data in the dataloader.
@@ -138,39 +168,47 @@ def val_epoch(val_loader, model, loss_fn, epochs_till_now = None, final_epoch =
138
168
k = 0
139
169
140
170
with torch .no_grad ():
171
+ batch_start_time = time .time ()
141
172
for batch_idx , (img , target ) in enumerate (val_loader ):
173
+ if test_only :
174
+ per = ((batch_idx + 1 )/ len (val_loader ))* 100
175
+ a_ , b_ = divmod (per , 1 )
176
+ print (f'{ str (batch_idx + 1 ).zfill (len (str (len (val_loader ))))} /{ str (len (val_loader )).zfill (len (str (len (val_loader ))))} ({ str (int (a_ )).zfill (2 )} .{ str (int (100 * b_ )).zfill (2 )} %)' , end = '\r ' )
142
177
# print(type(img), img.shape) # , np.unique(img))
143
178
144
- batch_start_time = time .time ()
145
-
179
+ img = img .to (device )
180
+ target = target .to (device )
181
+
146
182
out = model (img )
147
183
loss = loss_fn (out , target )
148
184
running_val_loss += loss .item ()* img .shape [0 ]
149
185
val_loss_list .append (loss .item ())
150
186
151
187
# storing model predictions for metric evaluat`ion
152
- probs [k : k + out .shape [0 ], :] = out
153
- gt [ k : k + out .shape [0 ], :] = target
188
+ probs [k : k + out .shape [0 ], :] = out . cpu ()
189
+ gt [ k : k + out .shape [0 ], :] = target . cpu ()
154
190
k += out .shape [0 ]
155
191
156
192
if ((batch_idx + 1 )% log_interval == 0 ) and (not test_only ): # only when ((batch_idx + 1) is divisible by log_interval) and (when test_only = False)
157
193
# batch metric evaluation
158
- batch_roc_auc_score = get_roc_auc_score (target , out )
194
+ # batch_roc_auc_score = get_roc_auc_score(target, out)
159
195
160
196
batch_time = time .time () - batch_start_time
161
197
m , s = divmod (batch_time , 60 )
162
- print ('Val Loss for batch {}/{} @epoch{}/{}: {} and batch_roc_auc_score: {} in {} mins {} secs' .format (str (batch_idx + 1 ).zfill (3 ), str (len (val_loader )).zfill (3 ), epochs_till_now , final_epoch , round (loss .item (), 5 ), round (batch_roc_auc_score , 5 ), int (m ), int (s )))
163
-
198
+ print ('Val Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs' .format (str (batch_idx + 1 ).zfill (3 ), str (len (val_loader )).zfill (3 ), epochs_till_now , final_epoch , round (loss .item (), 5 ), int (m ), round (s , 2 )))
199
+
200
+ batch_start_time = time .time ()
201
+
164
202
# metric scenes
165
203
roc_auc = get_roc_auc_score (gt , probs )
166
204
167
205
return val_loss_list , running_val_loss / float (len (val_loader .dataset )), roc_auc
168
206
169
- def fit (XRayTrain_dataset , train_loader , val_loader , test_loader , model ,
170
- loss_fn , optimizer , lr_scheduler , losses_dict ,
207
+ def fit (device , XRayTrain_dataset , train_loader , val_loader , test_loader , model ,
208
+ loss_fn , optimizer , losses_dict ,
171
209
epochs_till_now , epochs ,
172
210
log_interval , save_interval ,
173
- lr , bs , stage_num , test_only = False ):
211
+ lr , bs , stage , test_only = False ):
174
212
'''
175
213
Trains or Tests the 'model' on the given 'train_loader', 'val_loader', 'test_loader' for 'epochs' number of epochs.
176
214
If training ('test_only' = False), it saves the optimized 'model' and the loss plots ,after every 'save_interval'th epoch.
@@ -182,7 +220,7 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
182
220
if test_only :
183
221
print ('\n ======= Testing... =======\n ' )
184
222
test_start_time = time .time ()
185
- test_loss , mean_running_test_loss , test_roc_auc = val_epoch (test_loader , model , loss_fn , log_interval , test_only = test_only )
223
+ test_loss , mean_running_test_loss , test_roc_auc = val_epoch (device , test_loader , model , loss_fn , log_interval , test_only = test_only )
186
224
total_test_time = time .time () - test_start_time
187
225
m , s = divmod (total_test_time , 60 )
188
226
print ('test_roc_auc: {} in {} mins {} secs' .format (test_roc_auc , int (m ), int (s )))
@@ -208,17 +246,17 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
208
246
epoch_start_time = time .time ()
209
247
210
248
print ('TRAINING' )
211
- train_loss , mean_running_train_loss = train_epoch (train_loader , model , loss_fn , optimizer , lr_scheduler , epochs_till_now , final_epoch , log_interval )
249
+ train_loss , mean_running_train_loss = train_epoch (device , train_loader , model , loss_fn , optimizer , epochs_till_now , final_epoch , log_interval )
212
250
print ('VALIDATION' )
213
- val_loss , mean_running_val_loss , roc_auc = val_epoch (val_loader , model , loss_fn , epochs_till_now , final_epoch , log_interval )
251
+ val_loss , mean_running_val_loss , roc_auc = val_epoch (device , val_loader , model , loss_fn , epochs_till_now , final_epoch , log_interval )
214
252
215
253
epoch_train_loss .append (mean_running_train_loss )
216
254
epoch_val_loss .append (mean_running_val_loss )
217
255
218
256
total_train_loss_list .extend (train_loss )
219
257
total_val_loss_list .extend (val_loss )
220
258
221
- save_name = 'stage{}_{}_{}' .format (stage_num , str .split (str (lr ), '.' )[- 1 ], epochs_till_now )
259
+ save_name = 'stage{}_{}_{}' .format (stage , str .split (str (lr ), '.' )[- 1 ], str ( epochs_till_now ). zfill ( 2 ) )
222
260
223
261
# the follwoing piece of codw needs to be worked on !!! LATEST DEVELOPMENT TILL HERE
224
262
if ((epoch + 1 )% save_interval == 0 ) or test_only :
@@ -227,7 +265,6 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
227
265
torch .save ({
228
266
'epochs' : epochs_till_now ,
229
267
'model' : model , # it saves the whole model
230
- 'lr_scheduler_state_dict' : lr_scheduler .state_dict (), # dict_keys(['step_size', 'gamma', 'base_lrs', 'last_epoch'])
231
268
'losses_dict' : {'epoch_train_loss' : epoch_train_loss , 'epoch_val_loss' : epoch_val_loss , 'total_train_loss_list' : total_train_loss_list , 'total_val_loss_list' : total_val_loss_list }
232
269
}, save_path )
233
270
@@ -247,9 +284,6 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
247
284
248
285
249
286
250
-
251
-
252
-
253
287
'''
254
288
def pred_n_write(test_loader, model, save_name):
255
289
res = np.zeros((3000, 15), dtype = np.float32)
@@ -266,7 +300,6 @@ def pred_n_write(test_loader, model, save_name):
266
300
print('populating the csv')
267
301
submit = pd.DataFrame()
268
302
submit['ImageID'] = [str.split(i, os.sep)[-1] for i in test_loader.dataset.data_list]
269
-
270
303
with open('disease_classes.pickle', 'rb') as handle:
271
304
disease_classes = pickle.load(handle)
272
305
@@ -279,7 +312,6 @@ def pred_n_write(test_loader, model, save_name):
279
312
submit['No_findings'] = res[:, idx]
280
313
else:
281
314
submit[col] = res[:, idx]
282
-
283
315
rand_num = str(random.randint(1000, 9999))
284
316
csv_name = '{}___{}.csv'.format(save_name, rand_num)
285
317
submit.to_csv('res/' + csv_name, index = False)
0 commit comments