from nltk.translate.bleu_score import corpus_bleu
from utils import save_checkpoint
import neptune
- import time

- def validate(enc, dec, device, loss_fn, val_loader, wordmap, epoch):
+ def train(enc, dec, device, loss_fn, train_loader, optimizer_decoder, optimizer_encoder, epoch):
+     '''Train the model for one epoch.
+     :param enc: encoder part of the model
+     :param dec: decoder part of the model
+     :param device: device on which to train the model
+     :param loss_fn: loss function
+     :param train_loader: PyTorch loader of training images
+     :param optimizer_decoder: PyTorch optimizer for the decoder part of the model
+     :param optimizer_encoder: PyTorch optimizer for the encoder part of the model
+     :param epoch: current epoch of training
+     :return: None
+     '''
+
+     dec.train()
+     enc.train()
+
+     dec = dec.to(device)
+     enc = enc.to(device)
+
+     # Iterate through batches of the train loader
+     for batch_n, (imgs, caps, caplens) in enumerate(train_loader):
+
+         imgs = imgs.to(device)
+         caps = caps.to(device)
+         caplens = caplens.to(device)
+
+         # Encode images
+         enc_output = enc(imgs)
+
+         # Decode encodings and get captions
+         dec_out, captions, captions_lengths, sort_ind = dec(captions=caps,
+                                                             encoder_out=enc_output,
+                                                             captions_lengths=caplens)
+
+         # Drop positions that were never decoded (e.g. if the longest caption in the batch is
+         # 15 words, a 10-word caption has 5 padded positions that must be skipped in the loss)
+         dec_out = pack_padded_sequence(dec_out, captions_lengths.cpu(), batch_first=True).data.to(device)
+         captions = pack_padded_sequence(captions, captions_lengths.cpu(), batch_first=True).data.to(device)
+
+         loss = loss_fn(dec_out, captions)
+         optimizer_decoder.zero_grad()
+         optimizer_encoder.zero_grad()
+
+         loss.backward()
+
+         optimizer_decoder.step()
+         optimizer_encoder.step()
+
+         if batch_n % 3000 == 0:
+             save_checkpoint(epoch, batch_n, enc, dec, optimizer_encoder, optimizer_decoder)
+             print('Current loss', loss.item())
+
+         # Log metric to Neptune
+         neptune.log_metric('loss', loss.item())
+
+
+ def validate(enc, dec, device, val_loader, wordmap, epoch):
+     '''Calculate the validation metric (BLEU-4).
+     :param enc: encoder part of the model
+     :param dec: decoder part of the model
+     :param device: device on which to run validation
+     :param val_loader: PyTorch loader of validation images
+     :param wordmap: dictionary mapping from word to word index
+     :param epoch: current epoch of training
+     :return: None
+     '''
+
    enc.eval()
    dec.eval()
-     references = list()  # references (true captions) for calculating BLEU-4 score
-     hypotheses = list()  # hypotheses (predictions)
+
+     dec = dec.to(device)
+     enc = enc.to(device)
+
+     references = list()  # True captions
+     hypotheses = list()  # Predicted captions
+
    with torch.no_grad():
+
        for batch_n, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
+
            print(batch_n)
+
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)
+
            enc_output = enc(imgs)
            dec_out, captions, captions_lengths, sort_ind = dec(captions=caps,
                                                                encoder_out=enc_output,
                                                                captions_lengths=caplens)
-
            scores_copy = dec_out.clone()
-             dec_out = pack_padded_sequence(dec_out.cpu(), captions_lengths.cpu(), batch_first=True).data.to(device)
-             captions = pack_padded_sequence(captions.cpu(), captions_lengths.cpu(), batch_first=True).data.to(device)

-             loss = loss_fn(dec_out, captions)

-             allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
+             allcaps = allcaps[sort_ind]  # Re-sort because captions were sorted in the decoder
+
            for j in range(allcaps.shape[0]):
+
                img_caps = allcaps[j].tolist()
                img_captions = list(
                    map(lambda c: [w for w in c if w not in {wordmap['<start>'], wordmap['<pad>']}],
                        img_caps))  # remove <start> and pads
                references.append(img_captions)

-             # Hypotheses
+             # Take the predicted caption for each image
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:captions_lengths[j]])  # remove pads
            preds = temp_preds
            hypotheses.extend(preds)
+
        # Calculate BLEU-4 score
        bleu4 = corpus_bleu(references, hypotheses)
-     neptune.log_metric('bleu4', bleu4)
-     print('Epoch {}, BLEU4'.format(epoch), bleu4)
-
-
- def train(enc, dec, device, loss_fn, train_loader, optimizer_decoder, optimizer_encoder, epoch):
-
-     dec.train()  # train mode (dropout and batchnorm is used)
-     enc.train()
-
-     for batch_n, (imgs, caps, caplens) in enumerate(train_loader):
-         start = time.time()
-         imgs = imgs.to(device)
-         caps = caps.to(device)
-         caplens = caplens.to(device)
-         enc_output = enc(imgs)
-         dec_out, captions, captions_lengths, sort_ind = dec(captions=caps,
-                                                             encoder_out=enc_output,
-                                                             captions_lengths=caplens)
-         # if batch_n % 20 == 0:
-         #     aaaa = [res.get(int(key)) for key in torch.argmax(dec_out[0], dim=1)]
-         #     print('epoch:', epoch, 'batch', batch_n, aaaa)
-         #     img = Image.fromarray((unorm(imgs[0].cpu()).numpy()*255).astype('uint8').transpose(1, 2, 0))
-         #     img.save('{}-{}.png'.format(epoch, batch_n))
-         #     with open("captions.txt", "a") as f:
-         #         # Append 'hello' at the end of file
-         #         f.write("\n")
-         #         f.write(str(epoch) + '_' + str(batch_n) + '_' + str(aaaa))
-         dec_out = pack_padded_sequence(dec_out, captions_lengths.cpu(), batch_first=True).data.to(device)
-         captions = pack_padded_sequence(captions, captions_lengths.cpu(), batch_first=True).data.to(device)
-
-         loss = loss_fn(dec_out, captions)
-         optimizer_decoder.zero_grad()
-         optimizer_encoder.zero_grad()

-         loss.backward()
-
-         optimizer_decoder.step()
-         optimizer_encoder.step()
-         if batch_n % 3000 == 0:
-             save_checkpoint(epoch, batch_n, enc, dec, optimizer_encoder, optimizer_decoder)
-             print('Current loss', loss.item())
-
-         neptune.log_metric('loss', loss.item())
+     # Log score to Neptune and print the metric
+     neptune.log_metric('bleu4', bleu4)
+     print('Epoch {}, BLEU4'.format(epoch), bleu4)
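
Note on the packing step used in train() above: both the decoder scores and the target captions are packed with the same (CPU, sorted-descending) lengths, so padded positions drop out and the two flattened tensors stay aligned for the loss. A minimal runnable sketch of that behaviour, with made-up shapes and values (not taken from this commit):

# Sketch only: how pack_padded_sequence removes padded time steps before the loss.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

batch, max_len, vocab = 2, 5, 10
dec_out = torch.randn(batch, max_len, vocab)           # decoder scores, padded to max_len
captions = torch.randint(0, vocab, (batch, max_len))   # target word indices, padded
lengths = torch.tensor([5, 3])                         # real lengths, sorted descending

scores = pack_padded_sequence(dec_out, lengths, batch_first=True).data    # shape (8, 10)
targets = pack_padded_sequence(captions, lengths, batch_first=True).data  # shape (8,)

# Only the 5 + 3 = 8 real positions survive, so padding never contributes to the loss.
loss = torch.nn.CrossEntropyLoss()(scores, targets)
print(scores.shape, targets.shape, loss.item())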
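
validate() hands corpus_bleu word indices rather than word strings, which NLTK accepts because it only needs hashable tokens; each image contributes a list of reference captions and one predicted caption. A small sketch of the expected nesting, with made-up indices (not taken from this commit):

# Sketch only: reference/hypothesis nesting expected by corpus_bleu.
from nltk.translate.bleu_score import corpus_bleu

references = [
    [[3, 7, 9, 2], [3, 8, 9, 2]],      # image 0: two reference captions, as word indices
    [[4, 6, 5, 2], [4, 6, 1, 5, 2]],   # image 1: two reference captions
]
hypotheses = [
    [3, 7, 9, 2],                      # predicted caption for image 0
    [4, 6, 5, 2],                      # predicted caption for image 1
]

# Default weights give BLEU-4; prints 1.0 here because each prediction matches a reference exactly.
print(corpus_bleu(references, hypotheses))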