Skip to content

Commit 0e42615

Browse files
authored
added gpu support and corrected the roc_auc metric
1 parent 4b096a0 commit 0e42615

File tree

1 file changed

+64
-32
lines changed

1 file changed

+64
-32
lines changed

trainer.py

+64-32
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,48 @@
1+
import matplotlib
2+
matplotlib.use('Agg')
13
import matplotlib.pyplot as plt
4+
25
import sys, os, time, random, pdb
36
import numpy as np
47
import pandas as pd
58
import torch.nn.functional as F
69
import torch
710
import pickle
8-
import tqdm
11+
import tqdm, pdb
912
from sklearn.metrics import roc_auc_score
1013

1114
import config
1215

13-
def get_roc_auc_score(y_true, y_probs, average = 'micro'):
16+
def get_roc_auc_score(y_true, y_probs):
1417
'''
1518
Uses roc_auc_score function from sklearn.metrics to calculate the micro ROC AUC score for a given y_true and y_probs.
1619
'''
17-
return roc_auc_score(y_true, y_probs, average = average)
20+
21+
with open(os.path.join(config.pkl_dir_path, config.disease_classes_pkl_path), 'rb') as handle:
22+
all_classes = pickle.load(handle)
23+
24+
NoFindingIndex = all_classes.index('No Finding')
25+
26+
if True:
27+
print('\nNoFindingIndex: ', NoFindingIndex)
28+
print('y_true.shape, y_probs.shape ', y_true.shape, y_probs.shape)
29+
GT_and_probs = {'y_true': y_true, 'y_probs': y_probs}
30+
with open('GT_and_probs', 'wb') as handle:
31+
pickle.dump(GT_and_probs, handle, protocol = pickle.HIGHEST_PROTOCOL)
32+
33+
class_roc_auc_list = []
34+
useful_classes_roc_auc_list = []
35+
36+
for i in range(y_true.shape[1]):
37+
class_roc_auc = roc_auc_score(y_true[:, i], y_probs[:, i])
38+
class_roc_auc_list.append(class_roc_auc)
39+
if i != NoFindingIndex:
40+
useful_classes_roc_auc_list.append(class_roc_auc)
41+
if True:
42+
print('\nclass_roc_auc_list: ', class_roc_auc_list)
43+
print('\nuseful_classes_roc_auc_list', useful_classes_roc_auc_list)
44+
45+
return np.mean(np.array(useful_classes_roc_auc_list))
1846

1947
def make_plot(epoch_train_loss, epoch_val_loss, total_train_loss_list, total_val_loss_list, save_name):
2048
'''
@@ -78,25 +106,25 @@ def get_resampled_train_val_dataloaders(XRayTrain_dataset, transform, bs):
78106

79107
return train_loader, val_loader
80108

81-
def train_epoch(train_loader, model, loss_fn, optimizer, step_lr_scheduler, epochs_till_now, final_epoch, log_interval):
109+
def train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval):
82110
'''
83111
Takes in the data from the 'train_loader', calculates the loss over it using the 'loss_fn'
84112
and optimizes the 'model' using the 'optimizer'
85113
86114
Also prints the loss and the ROC AUC score for the batches, after every 'log_interval' batches.
87115
'''
88-
step_lr_scheduler.step() # the lr of the optimizer is multiplied with 'gamma' on the 'step_size'th time step() is called on step_lr_scheduler
89-
# if initial lr of the optimized is 0.001 and step_lr_scheduler has step_size = 2 and gamma = 0.8, on the 2nd call of step_lr_scheduler.step(), optimizer's lr becomes 0.001*gamma
90116
model.train()
91117

92118
running_train_loss = 0
93119
train_loss_list = []
94120

121+
start_time = time.time()
95122
for batch_idx, (img, target) in enumerate(train_loader):
96123
# print(type(img), img.shape) # , np.unique(img))
97-
98-
start_time = time.time()
99124

125+
img = img.to(device)
126+
target = target.to(device)
127+
100128
optimizer.zero_grad()
101129
out = model(img)
102130
loss = loss_fn(out, target)
@@ -108,19 +136,21 @@ def train_epoch(train_loader, model, loss_fn, optimizer, step_lr_scheduler, epoc
108136

109137
if (batch_idx+1)%log_interval == 0:
110138
# batch metric evaluation
111-
out_detached = out.detach()
112-
batch_roc_auc_score = get_roc_auc_score(target, out_detached.numpy())
139+
# # out_detached = out.detach()
140+
# # batch_roc_auc_score = get_roc_auc_score(target, out_detached.numpy())
113141
# 'out' is a torch.Tensor and 'roc_auc_score' function first tries to convert it into a numpy array, but since 'out' has requires_grad = True, it throws an error
114142
# RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
115143
# so we have to 'detach' the 'out' tensor and then convert it into a numpy array to avoid the error !
116144

117145
batch_time = time.time() - start_time
118146
m, s = divmod(batch_time, 60)
119-
print('Train Loss for batch {}/{} @epoch{}/{}: {} and batch_roc_auc_score: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(train_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), round(batch_roc_auc_score, 5), int(m), int(s)))
147+
print('Train Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(train_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))
148+
149+
start_time = time.time()
120150

121151
return train_loss_list, running_train_loss/float(len(train_loader.dataset))
122152

123-
def val_epoch(val_loader, model, loss_fn, epochs_till_now = None, final_epoch = None, log_interval = 1, test_only = False):
153+
def val_epoch(device, val_loader, model, loss_fn, epochs_till_now = None, final_epoch = None, log_interval = 1, test_only = False):
124154
'''
125155
It essentially takes in the val_loader/test_loader, the model and the loss function and evaluates
126156
the loss and the ROC AUC score for all the data in the dataloader.
@@ -138,39 +168,47 @@ def val_epoch(val_loader, model, loss_fn, epochs_till_now = None, final_epoch =
138168
k=0
139169

140170
with torch.no_grad():
171+
batch_start_time = time.time()
141172
for batch_idx, (img, target) in enumerate(val_loader):
173+
if test_only:
174+
per = ((batch_idx+1)/len(val_loader))*100
175+
a_, b_ = divmod(per, 1)
176+
print(f'{str(batch_idx+1).zfill(len(str(len(val_loader))))}/{str(len(val_loader)).zfill(len(str(len(val_loader))))} ({str(int(a_)).zfill(2)}.{str(int(100*b_)).zfill(2)} %)', end = '\r')
142177
# print(type(img), img.shape) # , np.unique(img))
143178

144-
batch_start_time = time.time()
145-
179+
img = img.to(device)
180+
target = target.to(device)
181+
146182
out = model(img)
147183
loss = loss_fn(out, target)
148184
running_val_loss += loss.item()*img.shape[0]
149185
val_loss_list.append(loss.item())
150186

151187
# storing model predictions for metric evaluat`ion
152-
probs[k: k + out.shape[0], :] = out
153-
gt[ k: k + out.shape[0], :] = target
188+
probs[k: k + out.shape[0], :] = out.cpu()
189+
gt[ k: k + out.shape[0], :] = target.cpu()
154190
k += out.shape[0]
155191

156192
if ((batch_idx+1)%log_interval == 0) and (not test_only): # only when ((batch_idx + 1) is divisible by log_interval) and (when test_only = False)
157193
# batch metric evaluation
158-
batch_roc_auc_score = get_roc_auc_score(target, out)
194+
# batch_roc_auc_score = get_roc_auc_score(target, out)
159195

160196
batch_time = time.time() - batch_start_time
161197
m, s = divmod(batch_time, 60)
162-
print('Val Loss for batch {}/{} @epoch{}/{}: {} and batch_roc_auc_score: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(val_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), round(batch_roc_auc_score, 5), int(m), int(s)))
163-
198+
print('Val Loss for batch {}/{} @epoch{}/{}: {} in {} mins {} secs'.format(str(batch_idx+1).zfill(3), str(len(val_loader)).zfill(3), epochs_till_now, final_epoch, round(loss.item(), 5), int(m), round(s, 2)))
199+
200+
batch_start_time = time.time()
201+
164202
# metric scenes
165203
roc_auc = get_roc_auc_score(gt, probs)
166204

167205
return val_loss_list, running_val_loss/float(len(val_loader.dataset)), roc_auc
168206

169-
def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
170-
loss_fn, optimizer, lr_scheduler, losses_dict,
207+
def fit(device, XRayTrain_dataset, train_loader, val_loader, test_loader, model,
208+
loss_fn, optimizer, losses_dict,
171209
epochs_till_now, epochs,
172210
log_interval, save_interval,
173-
lr, bs, stage_num, test_only = False):
211+
lr, bs, stage, test_only = False):
174212
'''
175213
Trains or Tests the 'model' on the given 'train_loader', 'val_loader', 'test_loader' for 'epochs' number of epochs.
176214
If training ('test_only' = False), it saves the optimized 'model' and the loss plots ,after every 'save_interval'th epoch.
@@ -182,7 +220,7 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
182220
if test_only:
183221
print('\n======= Testing... =======\n')
184222
test_start_time = time.time()
185-
test_loss, mean_running_test_loss, test_roc_auc = val_epoch(test_loader, model, loss_fn, log_interval, test_only = test_only)
223+
test_loss, mean_running_test_loss, test_roc_auc = val_epoch(device, test_loader, model, loss_fn, log_interval, test_only = test_only)
186224
total_test_time = time.time() - test_start_time
187225
m, s = divmod(total_test_time, 60)
188226
print('test_roc_auc: {} in {} mins {} secs'.format(test_roc_auc, int(m), int(s)))
@@ -208,17 +246,17 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
208246
epoch_start_time = time.time()
209247

210248
print('TRAINING')
211-
train_loss, mean_running_train_loss = train_epoch(train_loader, model, loss_fn, optimizer, lr_scheduler, epochs_till_now, final_epoch, log_interval)
249+
train_loss, mean_running_train_loss = train_epoch(device, train_loader, model, loss_fn, optimizer, epochs_till_now, final_epoch, log_interval)
212250
print('VALIDATION')
213-
val_loss, mean_running_val_loss, roc_auc = val_epoch(val_loader, model, loss_fn , epochs_till_now, final_epoch, log_interval)
251+
val_loss, mean_running_val_loss, roc_auc = val_epoch(device, val_loader, model, loss_fn , epochs_till_now, final_epoch, log_interval)
214252

215253
epoch_train_loss.append(mean_running_train_loss)
216254
epoch_val_loss.append(mean_running_val_loss)
217255

218256
total_train_loss_list.extend(train_loss)
219257
total_val_loss_list.extend(val_loss)
220258

221-
save_name = 'stage{}_{}_{}'.format(stage_num, str.split(str(lr), '.')[-1], epochs_till_now)
259+
save_name = 'stage{}_{}_{}'.format(stage, str.split(str(lr), '.')[-1], str(epochs_till_now).zfill(2))
222260

223261
# the follwoing piece of codw needs to be worked on !!! LATEST DEVELOPMENT TILL HERE
224262
if ((epoch+1)%save_interval == 0) or test_only:
@@ -227,7 +265,6 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
227265
torch.save({
228266
'epochs': epochs_till_now,
229267
'model': model, # it saves the whole model
230-
'lr_scheduler_state_dict': lr_scheduler.state_dict(), # dict_keys(['step_size', 'gamma', 'base_lrs', 'last_epoch'])
231268
'losses_dict': {'epoch_train_loss': epoch_train_loss, 'epoch_val_loss': epoch_val_loss, 'total_train_loss_list': total_train_loss_list, 'total_val_loss_list': total_val_loss_list}
232269
}, save_path)
233270

@@ -247,9 +284,6 @@ def fit(XRayTrain_dataset, train_loader, val_loader, test_loader, model,
247284

248285

249286

250-
251-
252-
253287
'''
254288
def pred_n_write(test_loader, model, save_name):
255289
res = np.zeros((3000, 15), dtype = np.float32)
@@ -266,7 +300,6 @@ def pred_n_write(test_loader, model, save_name):
266300
print('populating the csv')
267301
submit = pd.DataFrame()
268302
submit['ImageID'] = [str.split(i, os.sep)[-1] for i in test_loader.dataset.data_list]
269-
270303
with open('disease_classes.pickle', 'rb') as handle:
271304
disease_classes = pickle.load(handle)
272305
@@ -279,7 +312,6 @@ def pred_n_write(test_loader, model, save_name):
279312
submit['No_findings'] = res[:, idx]
280313
else:
281314
submit[col] = res[:, idx]
282-
283315
rand_num = str(random.randint(1000, 9999))
284316
csv_name = '{}___{}.csv'.format(save_name, rand_num)
285317
submit.to_csv('res/' + csv_name, index = False)

0 commit comments

Comments
 (0)