Commit b4ad5d3 ("Minor cleanup")
Parent: 6e4d798

5 files changed: +224 -33 lines

config/jsfusion-whole.json (+3 -3)

@@ -3,15 +3,15 @@
   "pipeline": [
     {
       "model": "models.jsfusion.model.JsFusionLoader",
-      "gpus": [0,1,2]
+      "gpus": [0]
    },
    {
      "model": "models.jsfusion.model.ResNetRunner",
-      "gpus": [0,1,2,3,4,5]
+      "gpus": [0]
    },
    {
      "model": "models.jsfusion.model.MCModelRunner",
-      "gpus": [4,5]
+      "gpus": [0]
    }
  ]
 }
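The change pins all three pipeline stages (video loader, ResNet feature extractor, multiple-choice model) to GPU 0 instead of spreading them across six GPUs. As a rough illustration of how a runner might consume this config — the loop and the load_class helper below are hypothetical sketches, not this repository's actual loading code:

    import importlib
    import json

    def load_class(path):
        # Hypothetical helper: resolve "pkg.module.ClassName" to the class object.
        module_name, class_name = path.rsplit('.', 1)
        return getattr(importlib.import_module(module_name), class_name)

    with open('config/jsfusion-whole.json') as f:
        config = json.load(f)

    for step in config['pipeline']:
        model_class = load_class(step['model'])  # e.g. models.jsfusion.model.ResNetRunner
        for gpu in step['gpus']:                 # one replica per listed GPU
            print('would place', model_class.__name__, 'on cuda:%d' % gpu)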

models/jsfusion/attention.py (+4 -8)

@@ -8,23 +8,19 @@ def add_timing_signal_nd(num_frames, video_channels):
     shape = [1, num_frames, video_channels]
     num_dims = len(shape) - 2
     channels = shape[-1]
+
+    position = torch.tensor(range(num_frames), dtype=torch.float32)
+    position = torch.unsqueeze(position, dim=1)
+
     num_timescales = channels // (num_dims * 2)
     log_timescale_increment = math.log(MAX_TIMESCALE / MIN_TIMESCALE) / (num_timescales - 1)
     inv_timescales = []
     for i in range(num_timescales):
         inv_timescales.append(1.0 * math.exp(-float(i) * log_timescale_increment))
-    dim = 0
-    length = shape[dim + 1]
-
-    position = torch.tensor(range(num_frames), dtype=torch.float32)
     inv_timescales = torch.tensor(inv_timescales, dtype=torch.float32)
-
-    position = torch.unsqueeze(position, dim=1)
-
     inv_timescales = torch.unsqueeze(inv_timescales, dim=0)
 
     scaled_time = position.matmul(inv_timescales)
-
     signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
     signal = torch.unsqueeze(signal, 0)
 
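The hunk deletes the unused dim/length locals and moves the position tensor up next to its unsqueeze; the computed value is unchanged: a Tensor2Tensor-style sinusoidal timing signal of shape (1, num_frames, 2 * (video_channels // 2)) that gets added to the frame features. A vectorized sketch of the same math — assuming MIN_TIMESCALE = 1.0 and MAX_TIMESCALE = 1.0e4, the usual Tensor2Tensor defaults; the constants' real values are defined elsewhere in attention.py:

    import math
    import torch

    MIN_TIMESCALE = 1.0    # assumed
    MAX_TIMESCALE = 1.0e4  # assumed

    def timing_signal(num_frames, video_channels):
        num_timescales = video_channels // 2
        log_inc = math.log(MAX_TIMESCALE / MIN_TIMESCALE) / (num_timescales - 1)
        # (num_timescales,) geometric ladder of inverse timescales
        inv_timescales = torch.exp(-log_inc * torch.arange(num_timescales, dtype=torch.float32))
        # (num_frames, 1) * (1, num_timescales) -> (num_frames, num_timescales)
        scaled_time = torch.arange(num_frames, dtype=torch.float32)[:, None] * inv_timescales[None, :]
        # concatenate sin and cos halves, add batch dim -> (1, num_frames, 2 * num_timescales)
        return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1).unsqueeze(0)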

models/jsfusion/data_util.py (+211, new file)

@@ -0,0 +1,211 @@
+"""Utility class used in JSFusion model, copied from the original author's code
+https://github.com/yj-yu/lsmdc/blob/master/videocap/datasets/data_util.py
+"""
+import time
+import numpy as np
+import re
+
+
+def clean_str(string, downcase=True):
+    """Tokenization/string cleaning for strings.
+
+    Taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
+    """
+    string = re.sub(r"[^A-Za-z0-9(),!?\'\`(_____)]", " ", string)
+    string = re.sub(r"\'s", " \'s", string)
+    string = re.sub(r"\'ve", " \'ve", string)
+    string = re.sub(r"n\'t", " n\'t", string)
+    string = re.sub(r"\'re", " \'re", string)
+    string = re.sub(r"\'d", " \'d", string)
+    string = re.sub(r"\'ll", " \'ll", string)
+    string = re.sub(r",", " , ", string)
+    string = re.sub(r"!", " ! ", string)
+    string = re.sub(r"\(", " \( ", string)
+    string = re.sub(r"\)", " \) ", string)
+    string = re.sub(r"\?", " \? ", string)
+    string = re.sub(r"\s{2,}", " ", string)
+    return string.strip().lower() if downcase else string.strip()
+
+def recover_word(string):
+    string = re.sub(r" \'s", "\'s", string)
+    string = re.sub(r" ,", ",", string)
+    return string
+
+def clean_blank(blank_sent):
+    """Tokenizes and changes _____ to <START>
+    <START> would be Answer position in FIB work.
+    """
+    clean_sent = clean_str(blank_sent).split()
+    return ['<START>' if x == '_____' else x for x in clean_sent]
+
+
+def clean_root(string):
+    """Removes unexpected character in root.
+    """
+    return string
+
+
+def pad_sequences(sequences, pad_token="[PAD]", pad_location="LEFT", max_length=None):
+    """Pads all sequences to the same length. The length is defined by the longest sequence.
+    Returns padded sequences.
+    """
+    if not max_length:
+        max_length = max(len(x) for x in sequences)
+
+    result = []
+    for i in range(len(sequences)):
+        sentence = sequences[i]
+        num_padding = max_length - len(sentence)
+        if num_padding == 0:
+            new_sentence = sentence
+        elif num_padding < 0:
+            new_sentence = sentence[:num_padding]
+        elif pad_location == "RIGHT":
+            new_sentence = sentence + [pad_token] * num_padding
+        elif pad_location == "LEFT":
+            new_sentence = [pad_token] * num_padding + sentence
+        else:
+            print("Invalid pad_location. Specify LEFT or RIGHT.")
+        result.append(new_sentence)
+    return result
+
+
+def convert_sent_to_index(sentence, word_to_index):
+    """Converts sentence consisting of string to indexed sentence.
+    """
+    return [word_to_index[word] if word in word_to_index.keys() else 0 for word in sentence]
+
+
+def batch_iter(data, batch_size, seed=None, fill=True):
+    """Generates a batch iterator for a dataset.
+    """
+    random = np.random.RandomState(seed)
+    data_length = len(data)
+    num_batches = int(data_length / batch_size)
+    if data_length % batch_size != 0:
+        num_batches += 1
+
+    # Shuffle the data at each epoch
+    shuffle_indices = random.permutation(np.arange(data_length))
+    for batch_num in range(num_batches):
+        start_index = batch_num * batch_size
+        end_index = min((batch_num + 1) * batch_size, data_length)
+        selected_indices = shuffle_indices[start_index:end_index]
+        # If we don't have enough data left for a whole batch, fill it randomly
+        if fill and end_index >= data_length:
+            num_missing = batch_size - len(selected_indices)
+            selected_indices = np.concatenate([selected_indices, random.randint(0, data_length, num_missing)])
+        yield [data[i] for i in selected_indices]
+
+
+def fsr_iter(fsr_data, batch_size, random_seed=42, fill=True):
+    """fsr_data: one of LSMDCData.build_data(), [[video_features], [sentences], [roots]]
+    return per iter: [[feature]*batch_size, [sentences]*batch_size, [roots]*batch]
+
+    Usage:
+        train_data, val_data, test_data = LSMDCData.build_data()
+        for features, sentences, roots in fsr_iter(train_data, 20, 10):
+            feed_dict = {model.video_feature : features,
+                         model.sentences : sentences,
+                         model.roots : roots}
+    """
+
+    train_iter = batch_iter(list(zip(*fsr_data)), batch_size, fill=fill, seed=random_seed)
+    return map(lambda batch: zip(*batch), train_iter)
+
+
+def preprocess_sents(descriptions, word_to_index, max_length):
+    descriptions = [clean_str(sent).split() for sent in descriptions]
+    # Add padding on the right to each sentence in order to keep the same lengths.
+    descriptions = pad_sequences(descriptions, max_length=max_length)
+    # Convert sentences from a list of string to the list of indices (int)
+    descriptions = [convert_sent_to_index(sent, word_to_index) for sent in descriptions]
+
+    return descriptions
+# remove punctuation mark and special chars from root.
+
+
+def preprocess_roots(roots, word_to_index):
+    roots = [clean_root(root) for root in roots]
+    # convert string to int index.
+    roots = [word_to_index[root] if root in word_to_index.keys() else 0 for root in roots]
+
+    return roots
+
+
+def pad_video(video_feature, dimension, padded_feature=None):
+    """Fills pad to video to have same length.
+    Pad in Left.
+    video = [pad,..., pad, frm1, frm2, ..., frmN]
+    """
+    if padded_feature is None:
+        padded_feature = np.zeros(dimension, dtype=np.float32)
+    max_length = dimension[0]
+    current_length = video_feature.shape[0]
+    num_padding = max_length - current_length
+    if num_padding == 0:
+        padded_feature[:] = video_feature
+    elif num_padding < 0:
+        steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
+        padded_feature[:] = video_feature[steps]
+    else:
+        # about 0.7 sec
+        padded_feature[num_padding:] = video_feature
+
+    return padded_feature
+
+def repeat_pad_video(video_feature, dimension):
+    padded_feature = np.zeros(dimension, dtype=np.float)
+    max_length = dimension[0]
+    current_length = video_feature.shape[0]
+
+    if current_length == max_length:
+        padded_feature[:] = video_feature
+
+    elif current_length < max_length:
+        tile_num = int(max_length / current_length)
+        to_tile = np.ones(len(dimension), dtype=np.int32)
+        to_tile[0] = tile_num
+        remainder = max_length % current_length
+        tiled_vid = np.tile(video_feature, to_tile)
+        if remainder > 0:
+            padded_feature[0:remainder] = video_feature[-remainder:]
+        padded_feature[remainder:] = tiled_vid
+
+    else:
+        steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
+        padded_feature[:] = video_feature[steps]
+    return padded_feature
+
+def stretch_pad_video(video_feature, dimension):
+    padded_feature = np.zeros(dimension, dtype=np.float)
+    max_length = dimension[0]
+    current_length = video_feature.shape[0]
+
+    if current_length == max_length:
+        padded_feature[:] = video_feature
+    elif current_length < max_length:
+        repeat_num = int((max_length - 1) / current_length) + 1
+        tiled_vid = np.repeat(video_feature, repeat_num, 0)
+        steps = np.linspace(0, repeat_num * current_length, num=max_length, endpoint=False, dtype=np.int32)
+        padded_feature[:] = tiled_vid[steps]
+    else:
+        steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
+        padded_feature[:] = video_feature[steps]
+    return padded_feature
+
+
+def fill_mask(max_length, current_length, zero_location='LEFT'):
+    num_padding = max_length - current_length
+    if num_padding <= 0:
+        mask = np.ones(max_length)
+    elif zero_location == 'LEFT':
+        mask = np.ones(max_length)
+        for i in range(num_padding):
+            mask[i] = 0
+    elif zero_location == 'RIGHT':
+        mask = np.zeros(max_length)
+        for i in range(current_length):
+            mask[i] = 1
+
+    return mask
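The new module ports the original author's preprocessing helpers: tokenization (clean_str), index conversion (convert_sent_to_index), batching (batch_iter), three strategies for forcing variable-length frame features to a fixed length (left-pad with zeros, tile-repeat, or stretch by resampling), and fill_mask for marking which positions hold real frames. A small usage sketch — the 40x2048 shape matches the ResNet feature shape used elsewhere in this commit, while the 25- and 60-frame inputs are made-up examples:

    import numpy as np
    from models.jsfusion.data_util import pad_video, fill_mask

    short_clip = np.random.rand(25, 2048).astype(np.float32)  # fewer frames than needed
    long_clip = np.random.rand(60, 2048).astype(np.float32)   # more frames than needed

    # Left-pads the short clip: rows 0..14 stay zero, rows 15..39 hold the frames.
    padded = pad_video(short_clip, (40, 2048))

    # Too many frames: pad_video picks 40 evenly spaced frames via np.linspace.
    downsampled = pad_video(long_clip, (40, 2048))

    # Mask matching the left-padded layout: 15 zeros followed by 25 ones.
    mask = fill_mask(40, 25, zero_location='LEFT')
    print(padded.shape, downsampled.shape, mask.sum())  # (40, 2048) (40, 2048) 25.0

One caveat: repeat_pad_video and stretch_pad_video allocate with np.zeros(dimension, dtype=np.float); the np.float alias was removed in NumPy 1.24, so on current NumPy those two helpers raise AttributeError unless the alias is replaced with float or np.float32.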

models/jsfusion/model.py (-6)

@@ -10,8 +10,6 @@
 import nvvl
 import os
 
-NUM_FRAMES=40
-
 class JsFusionVideoPathIterator(VideoPathIterator):
     def __init__(self):
         super(JsFusionVideoPathIterator, self).__init__()
@@ -60,11 +58,8 @@ def __call__(self, input):
         frames = frames.float()
         frames = frames.permute(0, 2, 1, 3, 4)
 
-        ### TODO Directly apply this transform
         transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
-
-        ### TODO This logic seems to be simplified
         frames_tmp = []
         for frame in frames:
             frame = torch.squeeze(frame)
@@ -115,7 +110,6 @@ def __init__(self, device, num_frames = 40):
         self.model.eval()
 
     def input_shape(self):
-        # TODO Input shape
         return ((1, 40, 2048),)
 
     def __call__(self, input):
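The dropped TODOs flagged that torchvision's transforms.Normalize accepts a single CxHxW image, which is why the surviving code squeezes and normalizes frame by frame. A broadcasting sketch of the "directly apply this transform" idea the TODO hinted at — assuming frames is laid out as (N, C, T, H, W) after the permute above; this is an alternative formulation, not what the commit does:

    import torch

    def normalize_clip(frames):
        # frames: (N, C, T, H, W) float tensor scaled to [0, 1]
        mean = torch.tensor([0.485, 0.456, 0.406], device=frames.device)
        std = torch.tensor([0.229, 0.224, 0.225], device=frames.device)
        # Reshape to (1, C, 1, 1, 1) so the stats broadcast over batch, time, and space.
        return (frames - mean.view(1, -1, 1, 1, 1)) / std.view(1, -1, 1, 1, 1)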

models/jsfusion/module.py (+6 -16)

@@ -121,9 +121,9 @@ def __init__(self, device, dropout_prob = 0.5, video_channels = 2048, num_frames
         self.final_bn3 = torch.nn.BatchNorm1d(128, eps=0.001, momentum=0.001)
         self.final_fc4 = torch.nn.Linear(128, 1)
         self.final_bn4 = torch.nn.BatchNorm1d(1, eps=0.001, momentum=0.001)
-
         self.word2idx = hkl.load(os.path.join(os.environ['LSMDC_PATH'], 'hkls/common_word_to_index_py3.hkl'))
-
+
+
     def video_embeddings(self, video, mask):
         # BxLxC
         embedded_feat_tmp = video + self.signal
@@ -144,7 +144,6 @@ def video_embeddings(self, video, mask):
         relu2 = self.relu2(conv2)
         bn2 = self.bn2(relu2)
 
-
         conv3 = self.conv3(bn2)
         relu3 = self.relu3(conv3)
         bn3 = self.bn3(relu3)
@@ -185,7 +184,6 @@ def word_embeddings(self, captions, caption_masks):
 
         # 5BxLxH
         embedded_sentence = seq_embeddings * caption_masks
-        print('embedded_sentence', embedded_sentence.size, embedded_sentence.device, embedded_sentence)
 
         # 5BxLx1024
         outputs, _ = self.lstm(embedded_sentence)
@@ -202,7 +200,6 @@ def word_embeddings(self, captions, caption_masks):
         return rnn_output
 
 
-
     def fusion(self, v, w, mask, caption_masks):
         # 5Bx512xL
         v = v.repeat(5, 1, 1)
@@ -302,20 +299,17 @@ def fusion_next(self, output1, mask, caption_masks):
 
             cut_mask_indices = [i for i in range(cut_mask.shape[1]) if i % 2 == 1 and i < cut_mask.shape[1] - 1]
             cut_mask_indices = torch.tensor(cut_mask_indices)
-            cut_mask_indices = cut_mask_indices.to(device=self.device, non_blocking=True)#cuda(non_blocking=True)
+            cut_mask_indices = cut_mask_indices.to(device=self.device, non_blocking=True)
 
-            # cut_mask = torch.tensor([([0]*(max_len - l) + [1]*l) for l in cut_mask_len.cpu().numpy()],
-            #           dtype=torch.float32)
+            # cut_mask = torch.tensor([([0]*(max_len - l) + [1]*l) for l in cut_mask_len.cpu().numpy()], dtype=torch.float32)
             cut_mask = torch.index_select(cut_mask, 1, cut_mask_indices)
 
             cut_caption_masks_indices = [i for i in range(cut_caption_masks.shape[1]) if i % 2 == 1 and i > 1]
             cut_caption_masks_indices = torch.tensor(cut_caption_masks_indices)
-            cut_caption_masks_indices = cut_caption_masks_indices.to(device=self.device, non_blocking=True)#cuda(non_blocking=True)
-
+            cut_caption_masks_indices = cut_caption_masks_indices.to(device=self.device, non_blocking=True)
 
-            # cut_caption_masks = torch.tensor([([1]*l + [0]*(max_len - l)) for l in cut_caption_masks_len.cpu().numpy()],
-            #           dtype=torch.float32)
 
+            # cut_caption_masks = torch.tensor([([1]*l + [0]*(max_len - l)) for l in cut_caption_masks_len.cpu().numpy()], dtype=torch.float32)
             cut_caption_masks = torch.index_select(cut_caption_masks, 1, cut_caption_masks_indices)
 
             cut_mask_list.append(cut_mask.repeat(5, 1))
@@ -398,7 +392,6 @@ def fusion_next(self, output1, mask, caption_masks):
         return sum_state
 
 
-
    def final(self, fusion_next):
        # 5Bx256
        a = self.final_fc1(fusion_next)
@@ -420,9 +413,6 @@ def final(self, fusion_next):
         a = self.final_bn4(a)
 
         return torch.reshape(-a, (-1, 5))
-
-
-
 
 
     def parse_sentences(self, word2idx, mc, max_length):
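Besides whitespace cleanup, this hunk removes a debug print, drops trailing "#cuda(non_blocking=True)" remnants, and re-wraps two commented-out mask constructions onto single lines. The surviving logic keeps every second column of each mask via torch.index_select, which appears to track the halving of sequence length in the fusion stage. A toy demonstration of that pattern — the shapes are illustrative, not the model's real ones:

    import torch

    cut_mask = torch.tensor([[1., 1., 1., 1., 0., 0.]])  # (B=1, L=6) validity mask

    # Odd positions that are not the last column: [1, 3] for L = 6.
    indices = torch.tensor([i for i in range(cut_mask.shape[1])
                            if i % 2 == 1 and i < cut_mask.shape[1] - 1])

    # Keep one mask entry per pooled pair of positions -> shape (1, 2)
    halved = torch.index_select(cut_mask, 1, indices)
    print(halved)  # tensor([[1., 1.]])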
