Commit bdb3fc2

Port JsFusion to RnB

1 parent a4dc526 commit bdb3fc2

7 files changed: +898 -0 lines

config/jsfusion-whole.json (+17)

{
    "video_path_iterator": "models.jsfusion.model.JsFusionVideoPathIterator",
    "pipeline": [
        {
            "model": "models.jsfusion.model.JsFusionLoader",
            "gpus": [0]
        },
        {
            "model": "models.jsfusion.model.ResNetRunner",
            "gpus": [0]
        },
        {
            "model": "models.jsfusion.model.MCModelRunner",
            "gpus": [0]
        }
    ]
}
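
The config names each pipeline stage by its dotted Python path and pins every stage to GPU 0. A runner framework would typically resolve such paths with importlib; a minimal sketch of that idea (load_class and the surrounding wiring are illustrative assumptions, not part of this commit):

import importlib
import json

def load_class(dotted_path):
    # "models.jsfusion.model.JsFusionLoader" -> module path + class name
    module_name, class_name = dotted_path.rsplit('.', 1)
    return getattr(importlib.import_module(module_name), class_name)

with open('config/jsfusion-whole.json') as f:
    config = json.load(f)

video_path_iterator = load_class(config['video_path_iterator'])()
stages = [(load_class(step['model']), step['gpus']) for step in config['pipeline']]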

models/jsfusion/__init__.py

Whitespace-only changes.

models/jsfusion/attention.py (+27)

import torch
import math

MIN_TIMESCALE = 1.0
MAX_TIMESCALE = 1.0e4

def add_timing_signal_nd(num_frames, video_channels):
    shape = [1, num_frames, video_channels]
    num_dims = len(shape) - 2
    channels = shape[-1]

    position = torch.tensor(range(num_frames), dtype=torch.float32)
    position = torch.unsqueeze(position, dim=1)

    num_timescales = channels // (num_dims * 2)
    log_timescale_increment = math.log(MAX_TIMESCALE / MIN_TIMESCALE) / (num_timescales - 1)
    inv_timescales = []
    for i in range(num_timescales):
        inv_timescales.append(1.0 * math.exp(-float(i) * log_timescale_increment))
    inv_timescales = torch.tensor(inv_timescales, dtype=torch.float32)
    inv_timescales = torch.unsqueeze(inv_timescales, dim=0)

    scaled_time = position.matmul(inv_timescales)
    signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
    signal = torch.unsqueeze(signal, 0)

    return signal
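
This is the sinusoidal timing (positional) signal in the style of Tensor2Tensor's add_timing_signal_nd, specialized to a single time dimension. A quick shape check under the dimensions this commit uses elsewhere (40 frames, 2048 ResNet channels); the check itself is illustrative, not part of the file:

signal = add_timing_signal_nd(num_frames=40, video_channels=2048)
assert signal.shape == (1, 40, 2048)  # broadcastable over per-frame feature tensors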

models/jsfusion/data_util.py (+211)

"""Utility functions used in the JSFusion model, copied from the original author's code
https://github.com/yj-yu/lsmdc/blob/master/videocap/datasets/data_util.py
"""
import time
import numpy as np
import re


def clean_str(string, downcase=True):
    """Tokenization/string cleaning for strings.

    Taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`(_____)]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower() if downcase else string.strip()
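
For instance (illustrative, not part of the file), clean_str splits clitics and punctuation into their own tokens:

>>> clean_str("Someone's dog, running!")
"someone 's dog , running !"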

def recover_word(string):
    string = re.sub(r" \'s", "\'s", string)
    string = re.sub(r" ,", ",", string)
    return string

def clean_blank(blank_sent):
    """Tokenizes and changes _____ to <START>.
    <START> would be the answer position in FIB (fill-in-the-blank) work.
    """
    clean_sent = clean_str(blank_sent).split()
    return ['<START>' if x == '_____' else x for x in clean_sent]


def clean_root(string):
    """Removes unexpected characters in root."""
    return string


def pad_sequences(sequences, pad_token="[PAD]", pad_location="LEFT", max_length=None):
    """Pads all sequences to the same length. If max_length is not given,
    the length is defined by the longest sequence. Returns padded sequences.
    """
    if not max_length:
        max_length = max(len(x) for x in sequences)

    result = []
    for i in range(len(sequences)):
        sentence = sequences[i]
        num_padding = max_length - len(sentence)
        if num_padding == 0:
            new_sentence = sentence
        elif num_padding < 0:
            new_sentence = sentence[:num_padding]
        elif pad_location == "RIGHT":
            new_sentence = sentence + [pad_token] * num_padding
        elif pad_location == "LEFT":
            new_sentence = [pad_token] * num_padding + sentence
        else:
            raise ValueError("Invalid pad_location. Specify LEFT or RIGHT.")
        result.append(new_sentence)
    return result
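
A small illustrative call (not part of the file), showing the default left padding:

>>> pad_sequences([['a', 'b'], ['a', 'b', 'c', 'd']])
[['[PAD]', '[PAD]', 'a', 'b'], ['a', 'b', 'c', 'd']]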
def convert_sent_to_index(sentence, word_to_index):
    """Converts a sentence consisting of strings to an indexed sentence."""
    return [word_to_index[word] if word in word_to_index else 0 for word in sentence]


def batch_iter(data, batch_size, seed=None, fill=True):
    """Generates a batch iterator for a dataset."""
    random = np.random.RandomState(seed)
    data_length = len(data)
    num_batches = int(data_length / batch_size)
    if data_length % batch_size != 0:
        num_batches += 1

    # Shuffle the data at each epoch
    shuffle_indices = random.permutation(np.arange(data_length))
    for batch_num in range(num_batches):
        start_index = batch_num * batch_size
        end_index = min((batch_num + 1) * batch_size, data_length)
        selected_indices = shuffle_indices[start_index:end_index]
        # If we don't have enough data left for a whole batch, fill it randomly
        if fill and end_index >= data_length:
            num_missing = batch_size - len(selected_indices)
            selected_indices = np.concatenate([selected_indices, random.randint(0, data_length, num_missing)])
        yield [data[i] for i in selected_indices]


def fsr_iter(fsr_data, batch_size, random_seed=42, fill=True):
    """fsr_data: one of LSMDCData.build_data(), [[video_features], [sentences], [roots]]
    return per iter: [[feature]*batch_size, [sentences]*batch_size, [roots]*batch_size]

    Usage:
        train_data, val_data, test_data = LSMDCData.build_data()
        for features, sentences, roots in fsr_iter(train_data, 20, 10):
            feed_dict = {model.video_feature: features,
                         model.sentences: sentences,
                         model.roots: roots}
    """
    train_iter = batch_iter(list(zip(*fsr_data)), batch_size, fill=fill, seed=random_seed)
    return map(lambda batch: zip(*batch), train_iter)
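
To see the fill behavior concretely (toy data, not from the file): with five items and batch_size=3, the final short batch is topped up with randomly re-drawn indices, so every batch comes out full.

batches = list(batch_iter(list('abcde'), batch_size=3, seed=0))
assert len(batches) == 2 and all(len(b) == 3 for b in batches)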
def preprocess_sents(descriptions, word_to_index, max_length):
    descriptions = [clean_str(sent).split() for sent in descriptions]
    # Pad each sentence (on the left, the pad_sequences default) to the same length.
    descriptions = pad_sequences(descriptions, max_length=max_length)
    # Convert sentences from lists of strings to lists of indices (int).
    descriptions = [convert_sent_to_index(sent, word_to_index) for sent in descriptions]

    return descriptions


def preprocess_roots(roots, word_to_index):
    # Remove punctuation marks and special chars from each root.
    roots = [clean_root(root) for root in roots]
    # Convert string to int index.
    roots = [word_to_index[root] if root in word_to_index else 0 for root in roots]

    return roots


def pad_video(video_feature, dimension, padded_feature=None):
    """Pads video features so all videos have the same length.
    Pads on the left:
    video = [pad, ..., pad, frm1, frm2, ..., frmN]
    """
    if padded_feature is None:
        padded_feature = np.zeros(dimension, dtype=np.float32)
    max_length = dimension[0]
    current_length = video_feature.shape[0]
    num_padding = max_length - current_length
    if num_padding == 0:
        padded_feature[:] = video_feature
    elif num_padding < 0:
        # Too many frames: subsample evenly down to max_length.
        steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
        padded_feature[:] = video_feature[steps]
    else:
        # about 0.7 sec
        padded_feature[num_padding:] = video_feature

    return padded_feature

def repeat_pad_video(video_feature, dimension):
    padded_feature = np.zeros(dimension, dtype=np.float32)
    max_length = dimension[0]
    current_length = video_feature.shape[0]

    if current_length == max_length:
        padded_feature[:] = video_feature

    elif current_length < max_length:
        # Tile the whole clip until it fills max_length frames.
        tile_num = int(max_length / current_length)
        to_tile = np.ones(len(dimension), dtype=np.int32)
        to_tile[0] = tile_num
        remainder = max_length % current_length
        tiled_vid = np.tile(video_feature, to_tile)
        if remainder > 0:
            padded_feature[0:remainder] = video_feature[-remainder:]
        padded_feature[remainder:] = tiled_vid

    else:
        steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
        padded_feature[:] = video_feature[steps]
    return padded_feature

def stretch_pad_video(video_feature, dimension):
    padded_feature = np.zeros(dimension, dtype=np.float32)
    max_length = dimension[0]
    current_length = video_feature.shape[0]

    if current_length == max_length:
        padded_feature[:] = video_feature
    elif current_length < max_length:
        # Stretch by repeating each frame, then sample back down to max_length.
        repeat_num = int((max_length - 1) / current_length) + 1
        tiled_vid = np.repeat(video_feature, repeat_num, axis=0)
        steps = np.linspace(0, repeat_num * current_length, num=max_length, endpoint=False, dtype=np.int32)
        padded_feature[:] = tiled_vid[steps]
    else:
        steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
        padded_feature[:] = video_feature[steps]
    return padded_feature


def fill_mask(max_length, current_length, zero_location='LEFT'):
    num_padding = max_length - current_length
    if num_padding <= 0:
        mask = np.ones(max_length)
    elif zero_location == 'LEFT':
        mask = np.ones(max_length)
        for i in range(num_padding):
            mask[i] = 0
    elif zero_location == 'RIGHT':
        mask = np.zeros(max_length)
        for i in range(current_length):
            mask[i] = 1
    else:
        raise ValueError("Invalid zero_location. Specify LEFT or RIGHT.")

    return mask
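
As a concrete shape check (toy numbers, not from the file): a 25-frame feature array left-padded into the 40-slot layout, with the matching mask.

feats = np.random.rand(25, 2048).astype(np.float32)
padded = pad_video(feats, (40, 2048))  # zeros in slots 0-14, frames in slots 15-39
mask = fill_mask(40, 25)               # 15 zeros followed by 25 ones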

models/jsfusion/model.py (+121)

from models.jsfusion.module import ResNetFeatureExtractor
from models.jsfusion.module import MCModel
from models.jsfusion.sampler import FixedSampler

from runner_model import RunnerModel
from video_path_provider import VideoPathIterator
from itertools import cycle
from torchvision import transforms
import torch
import nvvl
import os

class JsFusionVideoPathIterator(VideoPathIterator):
    def __init__(self):
        super(JsFusionVideoPathIterator, self).__init__()

        videos = []
        video_dir = os.path.join(os.environ['LSMDC_PATH'], 'mp4s')
        for video in os.listdir(video_dir):
            videos.append(os.path.join(video_dir, video))

        if len(videos) <= 0:
            raise Exception('No video available.')

        self.videos_iter = cycle(videos)

    def __iter__(self):
        return self.videos_iter

class JsFusionLoader(RunnerModel):
    """Impl of loading video frames using NVVL, for the JsFusion model."""
    def __init__(self, device):
        self.loader = nvvl.RnBLoader(width=224, height=224,
                                     consecutive_frames=1, device_id=device.index,
                                     sampler=FixedSampler(num_frames=40))

        samples = [
            os.path.join(os.environ['LSMDC_PATH'], 'mp4s/1004_Juno_00.00.32.849-00.00.35.458.mp4'),
            os.path.join(os.environ['LSMDC_PATH'], 'mp4s/1004_Juno_00.00.35.642-00.00.45.231.mp4'),
            os.path.join(os.environ['LSMDC_PATH'], 'mp4s/1004_Juno_00.00.49.801-00.00.59.450.mp4')]

        # warm up GPU with a few inferences
        for sample in samples:
            self.loader.loadfile(sample)
        for frames in self.loader:
            pass
        self.loader.flush()

    def __call__(self, input):
        _, file_path = input
        self.loader.loadfile(file_path)
        for frames in self.loader:
            pass
        self.loader.flush()

        # frames: (40, 3, 1, 224, 224)
        frames = frames.float()
        frames = frames.permute(0, 2, 1, 3, 4)

        transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        frames_tmp = []
        for frame in frames:
            frame = torch.squeeze(frame)
            frame /= 255
            frame = transform(frame)
            frames_tmp.append(frame)
        frames = torch.stack(frames_tmp)
        # frames: (40, 3, 224, 224)

        filename = os.path.basename(file_path)
        out = (frames, filename)
        return out

    def __del__(self):
        self.loader.close()

    def input_shape(self):
        return None

    @staticmethod
    def output_shape():
        return ((40, 3, 224, 224),)


class ResNetRunner(RunnerModel):
    def __init__(self, device, num_frames=40):
        super(ResNetRunner, self).__init__(device)
        self.model = ResNetFeatureExtractor(num_frames).to(device)
        self.model.float()
        self.model.eval()

    def input_shape(self):
        return ((40, 3, 224, 224),)

    @staticmethod
    def output_shape():
        return ((1, 40, 2048),)

    def __call__(self, input):
        return self.model(input)


class MCModelRunner(RunnerModel):
    def __init__(self, device, num_frames=40):
        super(MCModelRunner, self).__init__(device)
        self.model = MCModel(device).to(device)
        self.model.float()
        self.model.eval()

    def input_shape(self):
        return ((1, 40, 2048),)

    def __call__(self, input):
        return self.model(input)

    @staticmethod
    def output_shape():
        return ((1,),)
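
The three runners chain by shape: the loader emits a (40, 3, 224, 224) frame tensor, ResNetRunner turns it into (1, 40, 2048) features, and MCModelRunner reduces those to a single (1,) output. In the benchmark the framework wires these stages together from the JSON config; a hypothetical manual invocation, just to show how the shapes line up (video_path is an assumed local .mp4 path, and framework plumbing such as cross-stage device transfers is ignored):

device = torch.device('cuda:0')
loader = JsFusionLoader(device)
resnet = ResNetRunner(device)
mc = MCModelRunner(device)

frames, filename = loader((None, video_path))  # (40, 3, 224, 224)
features = resnet(frames)                      # (1, 40, 2048)
score = mc(features)                           # (1,)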
