Commit 69cdd4d

Improved image loading and added image caching
1 parent c09392f commit 69cdd4d
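
In short: frame loading in run.py now uses a sliding window. Each iteration previously re-read all num_frames neighboring frames from disk; now the window from the previous iteration is cached in previous_lr_list, so each step drops the oldest frame and reads only the newest one. The hard-coded single-frame padding, which only worked for 3-frame models, is also generalized to (num_frames - 1) // 2 padding frames at each end. A minimal sketch of the caching idea (a standalone illustration, not the literal run.py code; load_frame is a hypothetical stand-in for the cv2.imread call or video-frame lookup):

    def sliding_windows(frames, num_frames, load_frame):
        # Pad so the first and last real frames still sit at the window center
        num_padding = (num_frames - 1) // 2
        frames = [frames[0]] * num_padding + list(frames) + [frames[-1]] * num_padding

        window = []
        for idx in range(num_padding, len(frames) - num_padding):
            if not window:
                # First pass: load the whole neighborhood around the center frame
                window = [load_frame(frames[idx + i])
                          for i in range(-num_padding, num_padding + 1)]
            else:
                # Later passes: drop the oldest frame, read only the newest one
                window = window[1:] + [load_frame(frames[idx + num_padding])]
            yield window

After the first window is built, each iteration costs one disk read instead of num_frames reads.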

File tree

2 files changed (+108 -88 lines)

Diff for: .gitignore (+2 -1)

@@ -1,3 +1,4 @@
 __pycache__/
 *.py[cod]
-*$py.class
+*$py.class
+.vscode/settings.json

Diff for: run.py (+106 -87)

@@ -39,6 +39,7 @@
 input_folder = os.path.normpath(args.input)
 output_folder = os.path.normpath(args.output)
 
+
 def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
     # divide into 4 patches
     b, n, c, h, w = x.size()
@@ -50,7 +51,6 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
                  x[:, :, :, (h - h_size):h, 0:w_size],
                  x[:, :, :, (h - h_size):h, (w - w_size):w]]
 
-
     if w_size * h_size < min_size:
         outputlist = []
         for i in range(0, 4, nGPUs):
@@ -61,7 +61,7 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
             outputlist.append(output_batch.data)
     else:
         outputlist = [
-            chop_forward(patch, model, scale, shave, min_size, nGPUs) \
+            chop_forward(patch, model, scale, shave, min_size, nGPUs)
             for patch in inputlist]
 
     h, w = scale * h, scale * w
@@ -76,12 +76,16 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
         if len(out.shape) < 4:
             outputlist[idx] = out.unsqueeze(0)
     output[:, :, 0:h_half, 0:w_half] = outputlist[0][:, :, 0:h_half, 0:w_half]
-    output[:, :, 0:h_half, w_half:w] = outputlist[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
-    output[:, :, h_half:h, 0:w_half] = outputlist[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
-    output[:, :, h_half:h, w_half:w] = outputlist[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
+    output[:, :, 0:h_half, w_half:w] = outputlist[1][:,
+                                                     :, 0:h_half, (w_size - w + w_half):w_size]
+    output[:, :, h_half:h, 0:w_half] = outputlist[2][:,
+                                                     :, (h_size - h + h_half):h_size, 0:w_half]
+    output[:, :, h_half:h, w_half:w] = outputlist[3][:, :,
+                                                     (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
 
     return output.float().cpu()
 
+
 def main():
     state_dict = torch.load(args.model)
 
@@ -92,7 +96,7 @@ def main():
     keys = state_dict.keys()
     # ESRGAN RRDB SR net
     if 'SR.model.1.sub.0.RDB1.conv1.0.weight' in keys:
-        # extract model information
+        # extract model information
         scale2 = 0
         max_part = 0
         for part in list(state_dict):
@@ -113,13 +117,14 @@ def main():
         nf = state_dict['SR.model.0.weight'].shape[0]
 
         if scale == 2:
-            if state_dict['OFR.SR.1.weight'].shape[0] == 576:
+            if state_dict['OFR.SR.1.weight'].shape[0] == 576:
                 scale = 3
 
         frame_size = state_dict['SR.model.0.weight'].shape[1]
         num_frames = (((frame_size - 3) // (3 * (scale ** 2))) + 1)
 
-        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3, sr_gaussian_noise=False)
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels,
+                              SR_net='rrdb', sr_nf=nf, sr_nb=nb, img_ch=3, sr_gaussian_noise=False)
         only_y = False
     # Default SOFVSR SR net
     else:
@@ -138,7 +143,8 @@ def main():
         # Extract num_frames from model
         frame_size = state_dict['SR.body.0.weight'].shape[1]
         num_frames = (((frame_size - 1) // scale ** 2) + 1)
-        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames, channels=num_channels, SR_net='sofvsr', img_ch=1)
+        model = SOFVSR.SOFVSR(scale=scale, n_frames=num_frames,
+                              channels=num_channels, SR_net='sofvsr', img_ch=1)
         only_y = True
 
     # Create model
@@ -151,10 +157,12 @@ def main():
 
         # Grabs video metadata information
        probe = ffmpeg.probe(args.input)
-        video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+        video_stream = next(
+            (stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
         width = int(video_stream['width'])
         height = int(video_stream['height'])
-        framerate = int(video_stream['r_frame_rate'].split('/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
+        framerate = int(video_stream['r_frame_rate'].split(
+            '/')[0]) / int(video_stream['r_frame_rate'].split('/')[1])
         vcodec = 'libx264'
         crf = args.crf
 
@@ -172,7 +180,7 @@ def main():
             .reshape([-1, height, width, 3])
         )
 
-        # Convert numpy array into frame list
+        # Convert numpy array into frame list
         images = []
         for i in range(video.shape[0]):
             frame = cv2.cvtColor(video[i], cv2.COLOR_RGB2BGR)
@@ -181,10 +189,10 @@ def main():
         # Open output file writer
         process = (
             ffmpeg
-            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
-            .output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
-            .overwrite_output()
-            .run_async(pipe_stdin=True)
+            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s='{}x{}'.format(width * scale, height * scale))
+            .output(args.output, pix_fmt='yuv420p', vcodec=vcodec, r=framerate, crf=crf, preset='veryfast')
+            .overwrite_output()
+            .run_async(pipe_stdin=True)
         )
     # Regular case with input/output frame images
     else:
@@ -193,99 +201,109 @@ def main():
             for file in sorted(files):
                 if file.split('.')[-1].lower() in ['png', 'jpg', 'jpeg', 'gif', 'bmp', 'tiff', 'tga']:
                     images.append(os.path.join(root, file))
-
+
     # Pad beginning and end frames so they get included in output
-    images.insert(0, images[0])
-    images.append(images[-1])
+    num_padding = (num_frames - 1) // 2
+    for _ in range(num_padding):
+        images.insert(0, images[0])
+        images.append(images[-1])
+
+    previous_lr_list = []
 
     # Inference loop
-    for idx, path in enumerate(images[1:-1], 0):
-        idx_center = (num_frames - 1) // 2
-        idx_frame = idx
-
+    for idx in range(num_padding, len(images) - num_padding):
+
         # Only print this if processing frames
         if not is_video:
-            img_name = os.path.splitext(os.path.basename(path))[0]
-            print(idx_frame, img_name)
-
-        # read LR frames
-        LR_list = []
-        LR_bicubic = None
-        for i_frame in range(num_frames):
-            # Last and second to last frames
-            if idx == len(images)-2 and num_frames == 3:
-                # print("second to last frame:", i_frame)
-                if i_frame == 0:
-                    LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
-                else:
-                    LR_img = images[idx+1] if is_video else cv2.imread(images[idx_frame+1], cv2.IMREAD_COLOR)
-            elif idx == len(images)-1 and num_frames == 3:
-                # print("last frame:", i_frame)
-                LR_img = images[idx] if is_video else cv2.imread(images[idx_frame], cv2.IMREAD_COLOR)
-            # Every other internal frame
-            else:
-                # print("normal frame:", idx_frame)
-                LR_img = images[idx+i_frame] if is_video else cv2.imread(images[idx_frame+i_frame], cv2.IMREAD_COLOR)
-
+            img_name = os.path.splitext(os.path.basename(images[idx]))[0]
+            print(idx - num_padding, img_name)
+
+        # First pass
+        if idx == num_padding:
+            LR_list = []
+            LR_bicubic = None
+            # Load all beginning images on either side of current index
+            # E.g. num_frames = 7, from -3 to 3
+            for i in range(-num_padding, num_padding + 1):
+                # Read image or select video frame
+                LR_img = images[idx + i] if is_video else cv2.imread(
+                    images[idx + i], cv2.IMREAD_COLOR)
+                if not only_y:
+                    # TODO: Figure out why this is necessary
+                    LR_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
+                LR_list.append(LR_img)
+        # Other passes
+        else:
+            # Remove beginning frame from cached list
+            LR_list = previous_lr_list[1:]
+            # Load next image or video frame
+            new_img = images[idx + num_padding] if is_video else cv2.imread(
+                images[idx + num_padding], cv2.IMREAD_COLOR)
             if not only_y:
-                LR_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
-
-            # get the bicubic upscale of the center frame to concatenate for SR
-            if only_y and i_frame == idx_center:
-                if args.denoise:
-                    LR_bicubic = cv2.blur(LR_img, (3,3))
-                else:
-                    LR_bicubic = LR_img
-                LR_bicubic = util.imresize_np(img=LR_bicubic, scale=scale) # bicubic upscale
-
-            if only_y:
-                # extract Y channel from frames
-                # normal path, only Y for both
-                LR_img = util.bgr2ycbcr(LR_img, only_y=True)
-
-                # expand Y images to add the channel dimension
-                # normal path, only Y for both
-                LR_img = util.fix_img_channels(LR_img, 1)
+                # TODO: Figure out why this is necessary
+                new_img = cv2.cvtColor(LR_img, cv2.COLOR_BGR2RGB)
+            LR_list.append(new_img)
+        # Cache current list for next iter
+        previous_lr_list = LR_list
 
-            LR_list.append(LR_img) # h, w, c
+        # Convert LR_list to grayscale
+        if only_y:
+            gray_lr_list = []
+            LR_bicubic = LR_list[num_padding]
+            for i in range(len(LR_list)):
+                gray_lr = util.bgr2ycbcr(LR_list[i], only_y=True)
+                gray_lr = util.fix_img_channels(gray_lr, 1)
+                gray_lr_list.append(gray_lr)
+            LR_list = gray_lr_list
+
+        # Get the bicubic upscale of the center frame to concatenate for SR
+        if only_y:
+            if args.denoise:
+                LR_bicubic = cv2.blur(LR_bicubic, (3, 3))
+            else:
+                LR_bicubic = LR_bicubic
+            LR_bicubic = util.imresize_np(
+                img=LR_bicubic, scale=scale)  # bicubic upscale
 
-        if not only_y:
-            h_LR, w_LR, c = LR_img.shape
+        if not only_y:
+            h_LR, w_LR, c = LR_list[0].shape
 
         if not only_y:
             t = num_frames
-            LR = [np.asarray(LT) for LT in LR_list] # list -> numpy # input: list (contatin numpy: [H,W,C])
-            LR = np.asarray(LR) # numpy, [T,H,W,C]
-            LR = LR.transpose(1,2,3,0).reshape(h_LR, w_LR, -1) # numpy, [Hl',Wl',CT]
+            # list -> numpy # input: list (contatin numpy: [H,W,C])
+            LR = [np.asarray(LT) for LT in LR_list]
+            LR = np.asarray(LR)  # numpy, [T,H,W,C]
+            LR = LR.transpose(1, 2, 3, 0).reshape(
+                h_LR, w_LR, -1)  # numpy, [Hl',Wl',CT]
         else:
-            LR = np.concatenate((LR_list), axis=2) # h, w, t
+            LR = np.concatenate((LR_list), axis=2)  # h, w, t
 
         if only_y:
-            LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True) # Tensor, [CT',H',W'] or [T, H, W]
+            # Tensor, [CT',H',W'] or [T, H, W]
+            LR = util.np2tensor(LR, bgr2rgb=False, add_batch=True)
         else:
-            LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False) # Tensor, [CT',H',W'] or [T, H, W]
-            LR = LR.view(c,t,h_LR,w_LR) # Tensor, [C,T,H,W]
-            LR = LR.transpose(0,1) # Tensor, [T,C,H,W]
+            # Tensor, [CT',H',W'] or [T, H, W]
+            LR = util.np2tensor(LR, bgr2rgb=True, add_batch=False)
+            LR = LR.view(c, t, h_LR, w_LR)  # Tensor, [C,T,H,W]
+            LR = LR.transpose(0, 1)  # Tensor, [T,C,H,W]
         LR = LR.unsqueeze(0)
 
         if only_y:
             # generate Cr, Cb channels using bicubic interpolation
             LR_bicubic = util.bgr2ycbcr(LR_bicubic, only_y=False)
-            LR_bicubic = util.np2tensor(LR_bicubic, bgr2rgb=False, add_batch=True)
+            LR_bicubic = util.np2tensor(
+                LR_bicubic, bgr2rgb=False, add_batch=True)
         else:
             LR_bicubic = []
 
         if len(LR.size()) == 4:
             b, n_frames, h_lr, w_lr = LR.size()
-            LR = LR.view(b, -1, 1, h_lr, w_lr) # b, t, c, h, w
-        elif len(LR.size()) == 5: #for networks that work with 3 channel images
+            LR = LR.view(b, -1, 1, h_lr, w_lr)  # b, t, c, h, w
+        elif len(LR.size()) == 5:  # for networks that work with 3 channel images
             _, n_frames, _, _, _ = LR.size()
-            LR = LR # b, t, c, h, w
-
-
+            LR = LR  # b, t, c, h, w
 
         if args.chop_forward:
-
             # crop borders to ensure each patch can be divisible by 2
             _, _, _, h, w = LR.size()
             h = int(h//16) * 16
@@ -294,7 +312,7 @@ def main():
             if isinstance(LR_bicubic, torch.Tensor):
                 SR_cb = LR_bicubic[:, 1, :h * scale, :w * scale]
                 SR_cr = LR_bicubic[:, 2, :h * scale, :w * scale]
-
+
             SR_y = chop_forward(LR, model, scale).squeeze(0)
             if only_y:
                 sr_img = ycbcr_to_rgb(torch.stack((SR_y, SR_cb, SR_cr), -3))
@@ -309,23 +327,24 @@ def main():
             SR = fake_H.detach()[0].float().cpu()
         if only_y:
             SR_cb = LR_bicubic[:, 1, :, :]
-            SR_cr = LR_bicubic[:, 2, :, :]
+            SR_cr = LR_bicubic[:, 2, :, :]
             sr_img = ycbcr_to_rgb(torch.stack((SR, SR_cb, SR_cr), -3))
         else:
             sr_img = SR
-
+
         sr_img = util.tensor2np(sr_img) # uint8
 
         if not is_video:
            # save images
-            cv2.imwrite(os.path.join(output_folder, os.path.basename(path)), sr_img)
+            cv2.imwrite(os.path.join(output_folder,
+                                     os.path.basename(images[idx])), sr_img)
        else:
             # Write SR frame to output video stream
             sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)
             process.stdin.write(
                 sr_img
-                .astype(np.uint8)
-                .tobytes()
+                .astype(np.uint8)
+                .tobytes()
             )
 
     # Close output stream
