
Commit 957c67a
Author: Iliaavilov committed
Added comments to some functions and classes
1 parent 610f83c · commit 957c67a
7 files changed: +1341 -182 lines changed

dataset.py  +9 -12

@@ -7,32 +7,28 @@


 class CaptionDataset(Dataset):
-    """
-    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.
-    """

-    def __init__(self, data_folder, split, transform=None):
-        """
-        :param data_folder: folder where data files are stored
-        :param split: split, one of 'TRAIN', 'VAL', or 'TEST'
-        :param transform: image transform pipeline
-        """
+    def __init__(self, data_folder, split):
+        '''
+        :param data_folder: Folder where files are stored
+        :param split: which split of dataset (train, validation or test)
+        '''
         self.split = split
         assert self.split in {'TRAIN', 'VAL', 'TEST'}

         # Open hdf5 file where images are stored
-        self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + 'coco.hdf5'), 'r')
+        self.h = h5py.File(os.path.join(data_folder, self.split + '_IMAGES_' + 'flickr30k.hdf5'), 'r')
         self.imgs = self.h['images']

         # Captions per image
         self.cpi = self.h.attrs['captions_per_image']

         # Load encoded captions (completely into memory)
-        with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + 'coco.json'), 'r') as j:
+        with open(os.path.join(data_folder, self.split + '_CAPTIONS_' + 'flickr30k.json'), 'r') as j:
             self.captions = json.load(j)

         # Load caption lengths (completely into memory)
-        with open(os.path.join(data_folder, self.split + '_CAPLENS_' + 'coco.json'), 'r') as j:
+        with open(os.path.join(data_folder, self.split + '_CAPLENS_' + 'flickr30k.json'), 'r') as j:
             self.caplens = json.load(j)

         # PyTorch transformation pipeline for the image (normalizing, etc.)

@@ -61,4 +57,5 @@ def __getitem__(self, i):
         return img, caption, caplen, all_captions

     def __len__(self):
+
         return self.dataset_size
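For context, a minimal usage sketch (not part of this commit) of how CaptionDataset might be wired into a DataLoader; the data folder, batch size and split are placeholder values, and the files are assumed to follow the flickr30k naming introduced above:

from torch.utils.data import DataLoader
from dataset import CaptionDataset

# Placeholder folder; assumed to contain VAL_IMAGES_flickr30k.hdf5,
# VAL_CAPTIONS_flickr30k.json and VAL_CAPLENS_flickr30k.json
val_set = CaptionDataset(data_folder='data/', split='VAL')
val_loader = DataLoader(val_set, batch_size=32, shuffle=False, num_workers=2)

# As in __getitem__ above, each item carries the image, one encoded caption,
# its length, and all reference captions for that image
imgs, caps, caplens, allcaps = next(iter(val_loader))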

model.py  +17 -11

@@ -7,26 +7,30 @@
 class Encoder(nn.Module):

     def __init__(self, encoded_image_size=14):
+        '''
+        :param encoded_image_size: each encoded channel size of image will be encoded_image_size X encoded_image_size
+        '''
+
         super(Encoder, self).__init__()
         self.enc_image_size = encoded_image_size
-
-        # Load the pretrained resnet152
+        # Load pretrained resnet model
         resnet = torchvision.models.resnet152(pretrained=True)
-        # Remove the linear layers (we only need the CNN)
+        # Delete FC layers and leave only CNN
         modules = list(resnet.children())[:-2]
         self.resnet = nn.Sequential(*modules)
-        # Resize the image features to the required size
+        # Resize CNN features from resnet to appropriate size
         self.pooling = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
+        # Resnet parameters will not be modified during training
         for p in self.resnet.parameters():
             p.requires_grad = False

     def forward(self, images):

-        # Extract 2048 "channels" of features, each 7X7
+        # Obtain 2048 channels of features each of size 7X7
         out = self.resnet(images)  # (batch_size, 2048, 7, 7)
-        # Resize the channels to (encoded_image_size, encoded_image_size)
+        # Resize to (encoded_image_size, encoded_image_size)
         out = self.pooling(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
-        # Permute the dimensions (just for convenience)
+        # Change dimension order (just for convenience)
         out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
         return out

@@ -36,19 +40,23 @@ class Attention(nn.Module):
     def __init__(self, word_embeddings_dim, attention_dim, encoded_image_size):
         super(Attention, self).__init__()

+        # Attention layer for encoder
         self.att_encoder = nn.Linear(2048, attention_dim)
+        # Attention layer for decoder
         self.att_decoder = torch.nn.Linear(word_embeddings_dim, attention_dim)
+        # Final layer of attention
         self.att_final = torch.nn.Linear(attention_dim, 1)
         self.softmax = nn.Softmax(dim = 1)
         self.relu = nn.ReLU()

-    def forward(self, encoder_out, decoder_out, batch_size):
+    def forward(self, encoder_out, decoder_out):
+
         att_encoder_computed = self.att_encoder(encoder_out)  # (batch_size, encoded_image_size**2, attention_dim)
         att_decoder_computed = self.att_decoder(decoder_out)  # (batch_size, attention_dim)
         att = self.att_final(self.relu(att_encoder_computed + att_decoder_computed.unsqueeze(1))).squeeze(2)  # (batch_size, encoded_image_size**2)
         att_weights = self.softmax(att)  # (batch_size, 2048)
-
         encoder_weighted = (encoder_out * att_weights.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
+
         return encoder_weighted


@@ -62,13 +70,11 @@ def __init__(self, vocab_size, word_embeddings_dim, attention_dim, decoder_hidde
         self.word_embeddings_dim = word_embeddings_dim
         self.vocab_size = vocab_size
         self.encoded_image_size = encoded_image_size
-
         self.LSTMCell = torch.nn.LSTMCell(2048 + word_embeddings_dim,
                                           hidden_size=decoder_hidden_size, bias = True)
         self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=word_embeddings_dim)
         self.Attention = Attention(word_embeddings_dim, attention_dim, encoded_image_size)
         self.linear = torch.nn.Linear(decoder_hidden_size, vocab_size)
-
         self.h_init = torch.nn.Linear(2048, decoder_hidden_size)
         self.c_init = torch.nn.Linear(2048, decoder_hidden_size)
         self.f_beta = nn.Linear(decoder_hidden_size, 2048)  # linear layer to create a sigmoid-activated gate
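For orientation, a hedged shape-check sketch (not in this commit) of how the Encoder output feeds the Attention module described above; the embedding and attention dimensions, the image size, and the flattening step are illustrative assumptions:

import torch
from model import Encoder, Attention

encoder = Encoder(encoded_image_size=14)
attention = Attention(word_embeddings_dim=512, attention_dim=512, encoded_image_size=14)

images = torch.randn(2, 3, 224, 224)           # dummy batch of two RGB images
enc_out = encoder(images)                      # (2, 14, 14, 2048) after pooling and permute
enc_out = enc_out.reshape(2, -1, 2048)         # flatten to (2, 14*14, 2048) for attention
decoder_state = torch.randn(2, 512)            # stand-in for the decoder-side vector
weighted = attention(enc_out, decoder_state)   # (2, 2048) attention-weighted encoder features
print(weighted.shape)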

model_training.py  +83 -52

@@ -3,88 +3,119 @@
 from nltk.translate.bleu_score import corpus_bleu
 from utils import save_checkpoint
 import neptune
-import time

-def validate(enc, dec, device, loss_fn, val_loader, wordmap, epoch):
+def train(enc, dec, device, loss_fn, train_loader, optimizer_decoder, optimizer_encoder, epoch):
+    ''' Train model
+    :param enc: encoder part of model
+    :param dec: decoder part of model
+    :param device: on which device to train model
+    :param loss_fn: loss function
+    :param train_loader: pytorch loader of images
+    :param optimizer_decoder: pytorch optimizer for decoder part of model
+    :param optimizer_encoder: pytorch optimizer for encoder part of model
+    :param epoch: current epoch of training
+    :return: None
+    '''
+
+    dec.train()
+    enc.train()
+
+    dec = dec.to(device)
+    enc = enc.to(device)
+
+    # Iterate through batches of train loader
+    for batch_n, (imgs, caps, caplens) in enumerate(train_loader):
+
+        imgs = imgs.to(device)
+        caps = caps.to(device)
+        caplens = caplens.to(device)
+
+        # Encode images
+        enc_output = enc(imgs)
+
+        # Decode encodings and get captions
+        dec_out, captions, captions_lengths, sort_ind = dec(captions=caps,
+                                                            encoder_out=enc_output,
+                                                            captions_lengths=caplens)
+
+        # Remove words we did not decode (e.g. if the max sentence length in the batch is 15 words,
+        # then for a sentence of 10 words we did not decode 5 words, and we have to skip them when computing the loss)
+        dec_out = pack_padded_sequence(dec_out, captions_lengths.cpu(), batch_first=True).data.to(device)
+        captions = pack_padded_sequence(captions, captions_lengths.cpu(), batch_first=True).data.to(device)
+
+        loss = loss_fn(dec_out, captions)
+        optimizer_decoder.zero_grad()
+        optimizer_encoder.zero_grad()
+
+        loss.backward()
+
+        optimizer_decoder.step()
+        optimizer_encoder.step()
+
+        if batch_n % 3000 == 0:
+            save_checkpoint(epoch, batch_n, enc, dec, optimizer_encoder, optimizer_decoder)
+            print('Current loss', loss.item())
+
+        # Log metric to neptune
+        neptune.log_metric('loss', loss.item())
+
+
+def validate(enc, dec, device, val_loader, wordmap, epoch):
+    ''' Calculate validation metric
+    :param val_loader: pytorch loader of images
+    :param wordmap: dictionary mapping from word to word index
+    :param epoch: current epoch of training
+    :return: None
+    '''
+
     enc.eval()
     dec.eval()
-    references = list()  # references (true captions) for calculating BLEU-4 score
-    hypotheses = list()  # hypotheses (predictions)
+
+    dec = dec.to(device)
+    enc = enc.to(device)
+
+    references = list()  # True captions
+    hypotheses = list()  # Predicted captions
+
     with torch.no_grad():
+
         for batch_n, (imgs, caps, caplens, allcaps) in enumerate(val_loader):
+
             print(batch_n)
+
             imgs = imgs.to(device)
             caps = caps.to(device)
             caplens = caplens.to(device)
+
             enc_output = enc(imgs)
             dec_out, captions, captions_lengths, sort_ind = dec(captions=caps,
                                                                 encoder_out=enc_output,
                                                                 captions_lengths=caplens)
-
             scores_copy = dec_out.clone()
-            dec_out = pack_padded_sequence(dec_out.cpu(), captions_lengths.cpu(), batch_first=True).data.to(device)
-            captions = pack_padded_sequence(captions.cpu(), captions_lengths.cpu(), batch_first=True).data.to(device)

-            loss = loss_fn(dec_out, captions)

-            allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
+            allcaps = allcaps[sort_ind]  # Resort because captions were sorted in decoder
+
             for j in range(allcaps.shape[0]):
+
                 img_caps = allcaps[j].tolist()
                 img_captions = list(
                     map(lambda c: [w for w in c if w not in {wordmap['<start>'], wordmap['<pad>']}],
                         img_caps))  # remove <start> and pads
                 references.append(img_captions)

-            # Hypotheses
+            # Take predicted captions for each image
             _, preds = torch.max(scores_copy, dim=2)
             preds = preds.tolist()
             temp_preds = list()
             for j, p in enumerate(preds):
                 temp_preds.append(preds[j][:captions_lengths[j]])  # remove pads
             preds = temp_preds
             hypotheses.extend(preds)
+
     # Calculate BLEU-4 scores
     bleu4 = corpus_bleu(references, hypotheses)
-    neptune.log_metric('bleu4', bleu4)
-    print('Epoch {}, BLEU4'.format(epoch), bleu4)
-
-
-def train(enc, dec, device, loss_fn, train_loader, optimizer_decoder, optimizer_encoder, epoch):
-
-    dec.train()  # train mode (dropout and batchnorm is used)
-    enc.train()
-
-    for batch_n, (imgs, caps, caplens) in enumerate(train_loader):
-        start = time.time()
-        imgs = imgs.to(device)
-        caps = caps.to(device)
-        caplens = caplens.to(device)
-        enc_output = enc(imgs)
-        dec_out, captions, captions_lengths, sort_ind = dec(captions=caps,
-                                                            encoder_out=enc_output,
-                                                            captions_lengths=caplens)
-        # if batch_n % 20 == 0:
-        #     aaaa = [res.get(int(key)) for key in torch.argmax(dec_out[0], dim = 1)]
-        #     print('epoch:', epoch, 'batch', batch_n, aaaa)
-        #     img = Image.fromarray((unorm(imgs[0].cpu()).numpy()*255).astype('uint8').transpose(1, 2, 0))
-        #     img.save('{}-{}.png'.format(epoch, batch_n))
-        #     with open("captions.txt", "a") as f:
-        #         # Append 'hello' at the end of file
-        #         f.write("\n")
-        #         f.write(str(epoch) + '_' + str(batch_n) + '_' + str(aaaa))
-        dec_out = pack_padded_sequence(dec_out, captions_lengths.cpu(), batch_first=True).data.to(device)
-        captions = pack_padded_sequence(captions, captions_lengths.cpu(), batch_first=True).data.to(device)
-
-        loss = loss_fn(dec_out, captions)
-        optimizer_decoder.zero_grad()
-        optimizer_encoder.zero_grad()

-        loss.backward()
-
-        optimizer_decoder.step()
-        optimizer_encoder.step()
-        if batch_n % 3000 == 0:
-            save_checkpoint(epoch, batch_n, enc, dec, optimizer_encoder, optimizer_decoder)
-            print('Current loss', loss.item())
-
-        neptune.log_metric('loss', loss.item())
+    # Log score to neptune and print metric
+    neptune.log_metric('bleu4', bleu4)
+    print('Epoch {}, BLEU4'.format(epoch), bleu4)
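The pad-skipping comment in train() above is the key step; here is a self-contained sketch (illustrative tensors only, not the repo's data) of how pack_padded_sequence drops padded positions before the loss is computed:

import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pack_padded_sequence

vocab_size = 10
scores = torch.randn(3, 5, vocab_size)           # (batch, max_len, vocab) decoder scores
targets = torch.randint(0, vocab_size, (3, 5))   # (batch, max_len) padded target captions
lengths = torch.tensor([5, 4, 2])                # true caption lengths, sorted descending

# Only the first `length` timesteps of each sequence survive packing
packed_scores = pack_padded_sequence(scores, lengths, batch_first=True).data    # (5 + 4 + 2, vocab)
packed_targets = pack_padded_sequence(targets, lengths, batch_first=True).data  # (5 + 4 + 2,)

loss = CrossEntropyLoss()(packed_scores, packed_targets)
print(packed_scores.shape, packed_targets.shape, loss.item())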

server/server.py  +7 -14

@@ -9,7 +9,7 @@


 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-enc, dec = load_models(checkpoint_name = 'checkpoint_7_0.pth.tar')
+enc, dec = load_models(checkpoint_name = 'checkpoint_6_0.pth.tar')
 wordmap, res = load_wordmap()
 app = Flask(__name__)

@@ -23,7 +23,6 @@ def get_image_caption():
     no_image = False
     try:
         image = request.files['image']
-        print(type(image))
         image_hash = hash(image)
         image.save('images/{}.png'.format(image_hash))
         image = imread('images/{}.png'.format(image_hash))

@@ -34,8 +33,9 @@ def get_image_caption():
     image = image_preprocessing(image)
     if (image.shape == (3, 256, 256)) & (np.max(image) <= 256):
         image = image_normalisation(image, device)
-        predicted_captions, _ = captioning(enc, dec, image, wordmap, device, res)
-        return ' '.join(predicted_captions[1:])
+        predicted_captions, encoders_out = captioning(enc, dec, image, wordmap, device, res)
+        return jsonify({'captions': ' '.join(predicted_captions[1:]),
+                        'encoders_out': encoders_out.tolist()})
     else:
         return 'Image is not in png'

@@ -45,14 +45,13 @@ def get_video_captions():
     no_video = False
     try:
         video = request.files['video']
-        print(type(video))
         video_hash = hash(video)
         video.save('videos/{}.mp4'.format(video_hash))
     except:
         no_video = True
         return jsonify('No video file in post requests')
     if no_video == False:
-        video_to_screenshots('videos/{}.mp4'.format(video_hash), 'saved_screenshots', 200)
+        video_to_screenshots('videos/{}.mp4'.format(video_hash), 'saved_screenshots', 10)
         list_of_files = [f for f in listdir('saved_screenshots') if isfile(join('saved_screenshots', f))]
         all_captions = []
         all_encoders_out = []

@@ -63,25 +62,19 @@ def get_video_captions():
             if (image.shape == (3, 256, 256)) & (np.max(image) <= 256):
                 image = image_normalisation(image, device)
                 predicted_captions, encoder_out = captioning(enc, dec, image, wordmap, device, res)
-                all_captions.append(predicted_captions)
-                all_encoders_out.append(encoder_out)
+                all_captions.append(' '.join(predicted_captions[1:]))
+                all_encoders_out.append(encoder_out.tolist())

         return jsonify({'captions': all_captions,
                         'encoders_out': all_encoders_out})

-
-
-
-
-
 @app.route('/get_captions', methods=['POST', 'GET'])
 def get_captions():

     if request.method == 'POST':
         no_image = False
         try:
             image = request.files['image']
-            print(type(image))
             image_hash = hash(image)
             image.save('images/{}.png'.format(image_hash))
             image = imread('images/{}.png'.format(image_hash))
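For completeness, a hedged client-side sketch (not part of this commit) of how a caller might read the JSON payload the image endpoint now returns ('captions' plus 'encoders_out'); the host, port and '/get_image_caption' route are assumptions, since only the '/get_captions' route is visible in this diff:

import requests

# Hypothetical host/port and route; adjust to wherever the Flask app is served
with open('example.png', 'rb') as f:
    resp = requests.post('http://localhost:5000/get_image_caption', files={'image': f})

payload = resp.json()
print(payload['captions'])           # predicted caption string
print(len(payload['encoders_out']))  # encoder features, serialised via .tolist()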
