Commit 4b096a0

added gpu support and stage parameter in argparser
1 parent 3704fd3 commit 4b096a0

1 file changed (+41 -36 lines)

main.py (+41 -36)

@@ -4,8 +4,6 @@
 import pandas as pd
 from tqdm import tqdm
 import cv2
-# %matplotlib inline
-import matplotlib.pyplot as plt
 
 import torch
 import torch.nn as nn
@@ -17,7 +15,6 @@
 
 # import neccesary libraries for defining the optimizers
 import torch.optim as optim
-from torch.optim import lr_scheduler
 
 from trainer import fit
 import config
@@ -26,11 +23,15 @@ def q(text = ''): # easy way to exiting the script. useful while debugging
     print('> ', text)
     sys.exit()
 
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+print(f'\ndevice: {device}')
+
 parser = argparse.ArgumentParser(description='Following are the arguments that can be passed form the terminal itself ! Cool huh ? :D')
 parser.add_argument('--data_path', type = str, default = 'NIH Chest X-rays', help = 'This is the path of the training data')
 # parser.add_argument('--test_path', type = str, default = os.path.join('hack-data-new','Scoring2/') , help = 'This is the path of the testing data')
-parser.add_argument('--bs', type = int, default = 256, help = 'batch size')
+parser.add_argument('--bs', type = int, default = 128, help = 'batch size')
 parser.add_argument('--lr', type = float, default = 1e-5, help = 'Learning Rate for the optimizer')
+parser.add_argument('--stage', type = int, default = 1, help = 'Stage, it decides which layers of the Neural Net to train')
 parser.add_argument('--loss_func', type = str, default = 'FocalLoss', choices = {'BCE', 'FocalLoss'}, help = 'loss function')
 parser.add_argument('-r','--resume', action = 'store_true') # args.resume will return True if -r or --resume is used in the terminal
 parser.add_argument('--ckpt', type = str, help = 'Path of the ckeckpoint that you wnat to load')
@@ -40,6 +41,11 @@ def q(text = ''): # easy way to exiting the script. useful while debugging
 if args.resume and args.test: # what if --test is not defiend at all ? test case hai ye ek
     q('The flow of this code has been designed either to train the model or to test it.\nPlease choose either --resume or --test')
 
+stage = args.stage
+if not args.resume:
+    print(f'\nOverwriting stage to 1, as the model training is being done from scratch')
+    stage = 1
+
 if args.test:
     print('TESTING THE MODEL')
 else:
@@ -72,7 +78,7 @@ def count_parameters(model):
 print('-------------------------------------')
 
 # make the dataloaders
-batch_size = args.bs # 256 by default
+batch_size = args.bs # 128 by default
 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
 val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = not True)
 test_loader = torch.utils.data.DataLoader(XRayTest_dataset, batch_size = batch_size, shuffle = not True)
@@ -97,9 +103,9 @@ def count_parameters(model):
 # define the loss function
 if args.loss_func == 'FocalLoss': # by default
     from losses import FocalLoss
-    loss_fn = FocalLoss(gamma = 2.)
+    loss_fn = FocalLoss(device = device, gamma = 2.).to(device)
 elif args.loss_func == 'BCE':
-    loss_fn = nn.BCEWithLogitsLoss()
+    loss_fn = nn.BCEWithLogitsLoss().to(device)
 
 # define the learning rate
 lr = args.lr
@@ -114,7 +120,8 @@ def count_parameters(model):
         # change the last linear layer
         num_ftrs = model.fc.in_features
         model.fc = nn.Linear(num_ftrs, len(XRayTrain_dataset.all_classes)) # 15 output classes
-
+        model.to(device)
+
         print('----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
         for name, param in model.named_parameters(): # all requires_grad by default, are True initially
             # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
@@ -131,81 +138,85 @@ def count_parameters(model):
 
     else:
         if args.ckpt == None:
-            q('ERROR: Please select a checkpoint to resume from')
+            q('ERROR: Please select a valid checkpoint to resume from')
 
         print('\nckpt loaded: {}'.format(args.ckpt))
         ckpt = torch.load(os.path.join(config.models_dir, args.ckpt))
 
         # since we are resuming the training of the model
         epochs_till_now = ckpt['epochs']
         model = ckpt['model']
-
+        model.to(device)
+
         # loading previous loss lists to collect future losses
         losses_dict = ckpt['losses_dict']
 
         # printing some hyperparameters
         print('\n> loss_fn: {}'.format(loss_fn))
         print('> epochs_till_now: {}'.format(epochs_till_now))
         print('> batch_size: {}'.format(batch_size))
+        print('> stage: {}'.format(stage))
         print('> lr: {}'.format(lr))
 
 else: # testing
     if args.ckpt == None:
         q('ERROR: Please select a checkpoint to load the testing model from')
 
-    print('\nckpt loaded: {}'.format(args.ckpt))
+    print('\ncheckpoint loaded: {}'.format(args.ckpt))
     ckpt = torch.load(os.path.join(config.models_dir, args.ckpt))
 
     # since we are resuming the training of the model
    epochs_till_now = ckpt['epochs']
    model = ckpt['model']
-
+
    # loading previous loss lists to collect future losses
    losses_dict = ckpt['losses_dict']
 
 # make changes(freezing/unfreezing the model's layers) in the following, for training the model for different stages
-if not args.test:
-    if args.resume:
-        '''
+if (not args.test) and (args.resume):
+
+    if stage == 1:
+
         print('\n----- STAGE 1 -----') # only training 'layer2', 'layer3', 'layer4' and 'fc'
         for name, param in model.named_parameters(): # all requires_grad by default, are True initially
             # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
             if ('layer2' in name) or ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                 param.requires_grad = True
             else:
                 param.requires_grad = False
-        '''
 
-        '''
+    elif stage == 2:
+
         print('\n----- STAGE 2 -----') # only training 'layer3', 'layer4' and 'fc'
         for name, param in model.named_parameters():
             # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
             if ('layer3' in name) or ('layer4' in name) or ('fc' in name):
                 param.requires_grad = True
             else:
                 param.requires_grad = False
-        '''
 
-        '''
+    elif stage == 3:
+
         print('\n----- STAGE 3 -----') # only training 'layer4' and 'fc'
         for name, param in model.named_parameters():
             # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
             if ('layer4' in name) or ('fc' in name):
                 param.requires_grad = True
             else:
                 param.requires_grad = False
-        '''
 
-        # '''
+    elif stage == 4:
+
         print('\n----- STAGE 4 -----') # only training 'fc'
         for name, param in model.named_parameters():
             # print('{}: {}'.format(name, param.requires_grad)) # this shows True for all the parameters
             if ('fc' in name):
                 param.requires_grad = True
             else:
                 param.requires_grad = False
-        # '''
 
+
+if not args.test:
     # checking the layers which are going to be trained (irrespective of args.resume)
     trainable_layers = []
     for name, param in model.named_parameters():
@@ -219,20 +230,14 @@ def count_parameters(model):
 print('\nwe have {} Million trainable parameters here in the {} model'.format(count_parameters(model), model.__class__.__name__))
 
 optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr)
-step_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size = 2, gamma=0.8)
-
-if args.resume:
-    # the step_size and gamma defined will be overwritten by the ones saved in the state_dict of the previous step_lr_scheduler
-    step_lr_scheduler.load_state_dict(ckpt['lr_scheduler_state_dict']) # this will use the state_dict of the saved lr_scheduler
-    print('\nstep_lr_scheduler.state_dict(): ', step_lr_scheduler.state_dict())
 
 # make changes in the parameters of the following 'fit' function
-fit(XRayTrain_dataset, train_loader, val_loader,
+fit(device, XRayTrain_dataset, train_loader, val_loader,
     test_loader, model, loss_fn,
-    optimizer, step_lr_scheduler, losses_dict,
+    optimizer, losses_dict,
     epochs_till_now = epochs_till_now, epochs = 3,
-    log_interval = 5, save_interval = 1,
-    lr = lr, bs = batch_size, stage_num = 4,
+    log_interval = 25, save_interval = 1,
+    lr = lr, bs = batch_size, stage = stage,
     test_only = args.test)
 
 script_time = time.time() - script_start_time
@@ -247,11 +252,11 @@ def count_parameters(model):
 # epochs = 2
 # ##### STAGE 2 ##### FocalLoss lr = 3e-4
 # training layers = layer3, layer4, fc
-# epochs = 1
-# ##### STAGE 3 ##### FocalLoss lr = 1e-3
+# epochs = 5
+# ##### STAGE 3 ##### FocalLoss lr = 7e-4
 # training layers = layer4, fc
-# epochs = 3
+# epochs = 4
 # ##### STAGE 4 ##### FocalLoss lr = 1e-3
 # training layers = fc
-# epochs = 2
+# epochs = 3
 # '''
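
Note on the stage logic: the four STAGE branches added in this commit all follow the same pattern, namely choosing which layer-name prefixes of the ResNet stay trainable. A minimal sketch of that idea follows, assuming the torchvision ResNet parameter names used in this file ('layer2', 'layer3', 'layer4', 'fc'); the helper name and mapping below are illustrative only and are not part of this commit.

import torch.nn as nn

# Hypothetical helper, equivalent in effect to the stage == 1..4 branches in the diff above.
STAGE_TO_TRAINABLE = {
    1: ('layer2', 'layer3', 'layer4', 'fc'),
    2: ('layer3', 'layer4', 'fc'),
    3: ('layer4', 'fc'),
    4: ('fc',),
}

def freeze_for_stage(model: nn.Module, stage: int) -> None:
    prefixes = STAGE_TO_TRAINABLE[stage]
    for name, param in model.named_parameters():
        # a parameter stays trainable only if one of the stage's prefixes appears in its name
        param.requires_grad = any(prefix in name for prefix in prefixes)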
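On the GPU side, the commit selects a device, moves the model and the loss function onto it, and now passes device into trainer.fit; the per-batch transfer presumably happens inside fit, which is not shown in this diff. A rough sketch of that assumed pattern follows (the function and variable names are illustrative, not taken from trainer.py).

def train_one_epoch(model, loader, loss_fn, optimizer, device):
    # Illustrative only: how a training loop typically uses the device passed to fit().
    model.train()
    for imgs, targets in loader:
        # move each batch to the same device as the model (GPU if available, else CPU)
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(imgs), targets)
        loss.backward()
        optimizer.step()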
