@@ -37,7 +37,11 @@ def forward(self, images):
 
 class Attention(nn.Module):
 
-    def __init__(self, word_embeddings_dim, attention_dim, encoded_image_size):
+    def __init__(self, word_embeddings_dim, attention_dim):
+        '''
+        :param word_embeddings_dim: length of word embedding
+        :param attention_dim: length of attention vector
+        '''
         super(Attention, self).__init__()
 
         # Attention layer for encoder
@@ -50,19 +54,35 @@ def __init__(self, word_embeddings_dim, attention_dim, encoded_image_size):
         self.relu = nn.ReLU()
 
     def forward(self, encoder_out, decoder_out):
+        '''
+        :param encoder_out: embedding of image
+        :param decoder_out: embedding of previous word
+        :return: weighted encoded image
+        '''
 
-        att_encoder_computed = self.att_encoder(encoder_out)  # (batch_size, encoded_image_size**2, attention_dim)
-        att_decoder_computed = self.att_decoder(decoder_out)  # (batch_size, attention_dim)
-        att = self.att_final(self.relu(att_encoder_computed + att_decoder_computed.unsqueeze(1))).squeeze(2)  # (batch_size, encoded_image_size**2)
-        att_weights = self.softmax(att)  # (batch_size, 2048)
-        encoder_weighted = (encoder_out * att_weights.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
+        # Attention vector for the image
+        att_encoder_computed = self.att_encoder(encoder_out)
+        # Attention vector for the previous word
+        att_decoder_computed = self.att_decoder(decoder_out)
+        # Combining the two attention components
+        att = self.att_final(self.relu(att_encoder_computed + att_decoder_computed.unsqueeze(1))).squeeze(2)
+        # Weighting image parts based on attention
+        att_weights = self.softmax(att)
+        encoder_weighted = (encoder_out * att_weights.unsqueeze(2)).sum(dim=1)
 
         return encoder_weighted
 
 
 class Decoder(nn.Module):
 
     def __init__(self, vocab_size, word_embeddings_dim, attention_dim, decoder_hidden_size, encoded_image_size):
+        '''
+        :param vocab_size: number of words in corpus
+        :param word_embeddings_dim: length of word embedding
+        :param attention_dim: length of attention vector
+        :param decoder_hidden_size: hidden size of the LSTM
+        :param encoded_image_size: size of each encoded image channel
+        '''
         super(Decoder, self).__init__()
 
         self.encoded_image_size = encoded_image_size
@@ -73,7 +93,7 @@ def __init__(self, vocab_size, word_embeddings_dim, attention_dim, decoder_hidde
         self.LSTMCell = torch.nn.LSTMCell(2048 + word_embeddings_dim,
                                           hidden_size=decoder_hidden_size, bias=True)
         self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=word_embeddings_dim)
-        self.Attention = Attention(word_embeddings_dim, attention_dim, encoded_image_size)
+        self.Attention = Attention(word_embeddings_dim, attention_dim)
         self.linear = torch.nn.Linear(decoder_hidden_size, vocab_size)
         self.h_init = torch.nn.Linear(2048, decoder_hidden_size)
         self.c_init = torch.nn.Linear(2048, decoder_hidden_size)
@@ -83,51 +103,53 @@ def __init__(self, vocab_size, word_embeddings_dim, attention_dim, decoder_hidde
 
 
     def forward(self, captions, encoder_out, captions_lengths):
+        '''
+        :param captions: captions for images
+        :param encoder_out: encoded images
+        :param captions_lengths: lengths of captions
+        :return: predictions, sorted captions, caption lengths and sort indices
+        '''
 
-
-        # Batch size (needed to initialise the vectors)
+        # Initialising the predictions tensor
         batch_size = encoder_out.size()[0]
-        # Initialise the predictions vector of shape (batch_size, max(captions_length), vocab_size)
-        # (i.e. each observation gets a vector of probabilities of each word appearing at each position of the sentence)
         predictions = torch.zeros(batch_size, max(captions_lengths), self.vocab_size).to(device)  # (batch_size, max(captions_length), vocab_size)
-        predictions[:, 0, 0] = 1  # set the probability of the first word to 1
-        # Flatten the channels (i.e. 2048 matrices of size encoded_image_size x encoded_image_size
-        # become 2048 vectors of length encoded_image_size**2)
-        encoder_out = encoder_out.view(batch_size, -1, 2048)  # (batch_size, max(captions_length), 2048)
-        # Sort observations in descending order of caption length
+        # First word of each caption is guaranteed to be <start>
+        predictions[:, 0, 0] = 1
+        # Flattening channels
+        encoder_out = encoder_out.view(batch_size, -1, 2048)
+        # Sort captions by their length (for a faster loop)
         captions_lengths, sort_ind = captions_lengths.squeeze(1).sort(dim=0, descending=True)
-        encoder_out = encoder_out[sort_ind]  # (batch_size, max(captions_length), 2048)
+        encoder_out = encoder_out[sort_ind]
         captions = captions[sort_ind]
-        # Turn the words into embeddings
-        embeddings = self.embedding(captions)  # (batch_size, max(captions_length), word_embeddings_dim)
-        # Initialise the LSTM vectors for the first word (using the image)
+        # Embedding each word of the captions
+        embeddings = self.embedding(captions)
+        # Initialising the LSTM vectors for the first word (using the image)
         h = self.h_init(encoder_out.mean(dim=1))  # (batch_size, decoder_hidden_size)
         c = self.c_init(encoder_out.mean(dim=1))  # (batch_size, decoder_hidden_size)
 
 
         for word_n in range(1, max(captions_lengths)):
-            # Number of observations whose caption length exceeds the given length
+            # Number of captions longer than the current position
             batch_size_n = sum([length > word_n for length in captions_lengths])
 
-            # Take the embedding of the word at position word_n - 1 (i.e. the previous word)
+            # Obtain the embedding of the previous word
             decoder_out = embeddings[:, (word_n - 1)]  # (batch_size, word_embeddings_dim)
 
 
-            # Attention mechanism
-            encoder_weighted = self.Attention(batch_size=batch_size_n,
-                                              encoder_out=encoder_out[:batch_size_n],
-                                              decoder_out=decoder_out[:batch_size_n])  # (batch_size, encoded_image_size**2)
+            # Attention mechanism
+            encoder_weighted = self.Attention(encoder_out=encoder_out[:batch_size_n],
+                                              decoder_out=decoder_out[:batch_size_n])
 
-            gate = self.sigmoid(self.f_beta(h[:batch_size_n]))  # gating scalar, (batch_size_t, encoder_dim)
+            gate = self.sigmoid(self.f_beta(h[:batch_size_n]))
             encoder_weighted = gate * encoder_weighted
 
-            # Concatenate the attention output with the previous-word information
-            decoder_in = torch.cat((encoder_weighted, decoder_out[:batch_size_n]), 1)  # (batch_size, encoded_image_size**2 + word_embeddings_dim)
+            # Concatenating attention output and previous word
+            decoder_in = torch.cat((encoder_weighted, decoder_out[:batch_size_n]), 1)
 
-            # Predict the probabilities of words appearing at the current position
-            h, c = self.LSTMCell(decoder_in, (h[:batch_size_n], c[:batch_size_n]))  # (batch_size, decoder_hidden_size)
-            predictions_word = self.linear(h)  # (batch_size, decoder_hidden_size)
-            # Write the predicted probabilities (not yet probabilities) into the vector
+            # Obtaining scores (not yet probabilities, since no softmax is applied at this step) for each word
+            h, c = self.LSTMCell(decoder_in, (h[:batch_size_n], c[:batch_size_n]))
+            predictions_word = self.linear(h)
+            # Store the scores in the predictions tensor
             predictions[:batch_size_n, word_n, :] = predictions_word
 
         return predictions, captions, captions_lengths, sort_ind
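
A minimal sketch of how the refactored Attention module might be exercised on its own with dummy tensors; the batch size of 4, word_embeddings_dim=512, attention_dim=256 and the 14x14 encoded image are illustrative assumptions, not values taken from this commit:

import torch

attention = Attention(word_embeddings_dim=512, attention_dim=256)   # assumed dims
encoder_out = torch.randn(4, 14 * 14, 2048)             # (batch_size, encoded_image_size**2, 2048)
decoder_out = torch.randn(4, 512)                       # (batch_size, word_embeddings_dim)
encoder_weighted = attention(encoder_out, decoder_out)  # (batch_size, 2048)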
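
Decoder.forward returns predictions padded past each caption's true length, so a typical training step (assumed here for illustration; the loss computation is not part of this commit) packs scores and targets before applying cross-entropy, skipping the forced <start> position. decoder, captions, encoder_out and captions_lengths are hypothetical variables from such a training loop:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

criterion = torch.nn.CrossEntropyLoss()

predictions, captions, captions_lengths, sort_ind = decoder(captions, encoder_out, captions_lengths)

# Targets are the captions shifted left by one word; position 0 is the forced <start> token
targets = captions[:, 1:]
decode_lengths = (captions_lengths - 1).tolist()

# Packing drops the padded positions so they do not contribute to the loss
# (captions come back from Decoder.forward already sorted by length)
scores = pack_padded_sequence(predictions[:, 1:], decode_lengths, batch_first=True).data
targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

loss = criterion(scores, targets)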