39
39
input_folder = os .path .normpath (args .input )
40
40
output_folder = os .path .normpath (args .output )
41
41
42
+
42
43
def chop_forward (x , model , scale , shave = 16 , min_size = 5000 , nGPUs = 1 ):
43
44
# divide into 4 patches
44
45
b , n , c , h , w = x .size ()
@@ -50,7 +51,6 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
50
51
x [:, :, :, (h - h_size ):h , 0 :w_size ],
51
52
x [:, :, :, (h - h_size ):h , (w - w_size ):w ]]
52
53
53
-
54
54
if w_size * h_size < min_size :
55
55
outputlist = []
56
56
for i in range (0 , 4 , nGPUs ):
@@ -61,7 +61,7 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
61
61
outputlist .append (output_batch .data )
62
62
else :
63
63
outputlist = [
64
- chop_forward (patch , model , scale , shave , min_size , nGPUs ) \
64
+ chop_forward (patch , model , scale , shave , min_size , nGPUs )
65
65
for patch in inputlist ]
66
66
67
67
h , w = scale * h , scale * w
@@ -76,12 +76,16 @@ def chop_forward(x, model, scale, shave=16, min_size=5000, nGPUs=1):
76
76
if len (out .shape ) < 4 :
77
77
outputlist [idx ] = out .unsqueeze (0 )
78
78
output [:, :, 0 :h_half , 0 :w_half ] = outputlist [0 ][:, :, 0 :h_half , 0 :w_half ]
79
- output [:, :, 0 :h_half , w_half :w ] = outputlist [1 ][:, :, 0 :h_half , (w_size - w + w_half ):w_size ]
80
- output [:, :, h_half :h , 0 :w_half ] = outputlist [2 ][:, :, (h_size - h + h_half ):h_size , 0 :w_half ]
81
- output [:, :, h_half :h , w_half :w ] = outputlist [3 ][:, :, (h_size - h + h_half ):h_size , (w_size - w + w_half ):w_size ]
79
+ output [:, :, 0 :h_half , w_half :w ] = outputlist [1 ][:,
80
+ :, 0 :h_half , (w_size - w + w_half ):w_size ]
81
+ output [:, :, h_half :h , 0 :w_half ] = outputlist [2 ][:,
82
+ :, (h_size - h + h_half ):h_size , 0 :w_half ]
83
+ output [:, :, h_half :h , w_half :w ] = outputlist [3 ][:, :,
84
+ (h_size - h + h_half ):h_size , (w_size - w + w_half ):w_size ]
82
85
83
86
return output .float ().cpu ()
84
87
88
+
85
89
def main ():
86
90
state_dict = torch .load (args .model )
87
91
@@ -92,7 +96,7 @@ def main():
92
96
keys = state_dict .keys ()
93
97
# ESRGAN RRDB SR net
94
98
if 'SR.model.1.sub.0.RDB1.conv1.0.weight' in keys :
95
- # extract model information
99
+ # extract model information
96
100
scale2 = 0
97
101
max_part = 0
98
102
for part in list (state_dict ):
@@ -113,13 +117,14 @@ def main():
113
117
nf = state_dict ['SR.model.0.weight' ].shape [0 ]
114
118
115
119
if scale == 2 :
116
- if state_dict ['OFR.SR.1.weight' ].shape [0 ] == 576 :
120
+ if state_dict ['OFR.SR.1.weight' ].shape [0 ] == 576 :
117
121
scale = 3
118
122
119
123
frame_size = state_dict ['SR.model.0.weight' ].shape [1 ]
120
124
num_frames = (((frame_size - 3 ) // (3 * (scale ** 2 ))) + 1 )
121
125
122
- model = SOFVSR .SOFVSR (scale = scale , n_frames = num_frames , channels = num_channels , SR_net = 'rrdb' , sr_nf = nf , sr_nb = nb , img_ch = 3 , sr_gaussian_noise = False )
126
+ model = SOFVSR .SOFVSR (scale = scale , n_frames = num_frames , channels = num_channels ,
127
+ SR_net = 'rrdb' , sr_nf = nf , sr_nb = nb , img_ch = 3 , sr_gaussian_noise = False )
123
128
only_y = False
124
129
# Default SOFVSR SR net
125
130
else :
@@ -138,7 +143,8 @@ def main():
138
143
# Extract num_frames from model
139
144
frame_size = state_dict ['SR.body.0.weight' ].shape [1 ]
140
145
num_frames = (((frame_size - 1 ) // scale ** 2 ) + 1 )
141
- model = SOFVSR .SOFVSR (scale = scale , n_frames = num_frames , channels = num_channels , SR_net = 'sofvsr' , img_ch = 1 )
146
+ model = SOFVSR .SOFVSR (scale = scale , n_frames = num_frames ,
147
+ channels = num_channels , SR_net = 'sofvsr' , img_ch = 1 )
142
148
only_y = True
143
149
144
150
# Create model
@@ -151,10 +157,12 @@ def main():
151
157
152
158
# Grabs video metadata information
153
159
probe = ffmpeg .probe (args .input )
154
- video_stream = next ((stream for stream in probe ['streams' ] if stream ['codec_type' ] == 'video' ), None )
160
+ video_stream = next (
161
+ (stream for stream in probe ['streams' ] if stream ['codec_type' ] == 'video' ), None )
155
162
width = int (video_stream ['width' ])
156
163
height = int (video_stream ['height' ])
157
- framerate = int (video_stream ['r_frame_rate' ].split ('/' )[0 ]) / int (video_stream ['r_frame_rate' ].split ('/' )[1 ])
164
+ framerate = int (video_stream ['r_frame_rate' ].split (
165
+ '/' )[0 ]) / int (video_stream ['r_frame_rate' ].split ('/' )[1 ])
158
166
vcodec = 'libx264'
159
167
crf = args .crf
160
168
@@ -172,7 +180,7 @@ def main():
172
180
.reshape ([- 1 , height , width , 3 ])
173
181
)
174
182
175
- # Convert numpy array into frame list
183
+ # Convert numpy array into frame list
176
184
images = []
177
185
for i in range (video .shape [0 ]):
178
186
frame = cv2 .cvtColor (video [i ], cv2 .COLOR_RGB2BGR )
@@ -181,10 +189,10 @@ def main():
181
189
# Open output file writer
182
190
process = (
183
191
ffmpeg
184
- .input ('pipe:' , format = 'rawvideo' , pix_fmt = 'rgb24' , s = '{}x{}' .format (width * scale , height * scale ))
185
- .output (args .output , pix_fmt = 'yuv420p' , vcodec = vcodec , r = framerate , crf = crf , preset = 'veryfast' )
186
- .overwrite_output ()
187
- .run_async (pipe_stdin = True )
192
+ .input ('pipe:' , format = 'rawvideo' , pix_fmt = 'rgb24' , s = '{}x{}' .format (width * scale , height * scale ))
193
+ .output (args .output , pix_fmt = 'yuv420p' , vcodec = vcodec , r = framerate , crf = crf , preset = 'veryfast' )
194
+ .overwrite_output ()
195
+ .run_async (pipe_stdin = True )
188
196
)
189
197
# Regular case with input/output frame images
190
198
else :
@@ -193,99 +201,109 @@ def main():
193
201
for file in sorted (files ):
194
202
if file .split ('.' )[- 1 ].lower () in ['png' , 'jpg' , 'jpeg' , 'gif' , 'bmp' , 'tiff' , 'tga' ]:
195
203
images .append (os .path .join (root , file ))
196
-
204
+
197
205
# Pad beginning and end frames so they get included in output
198
- images .insert (0 , images [0 ])
199
- images .append (images [- 1 ])
206
+ num_padding = (num_frames - 1 ) // 2
207
+ for _ in range (num_padding ):
208
+ images .insert (0 , images [0 ])
209
+ images .append (images [- 1 ])
210
+
211
+ previous_lr_list = []
200
212
201
213
# Inference loop
202
- for idx , path in enumerate (images [1 :- 1 ], 0 ):
203
- idx_center = (num_frames - 1 ) // 2
204
- idx_frame = idx
205
-
214
+ for idx in range (num_padding , len (images ) - num_padding ):
215
+
206
216
# Only print this if processing frames
207
217
if not is_video :
208
- img_name = os .path .splitext (os .path .basename (path ))[0 ]
209
- print (idx_frame , img_name )
210
-
211
- # read LR frames
212
- LR_list = []
213
- LR_bicubic = None
214
- for i_frame in range (num_frames ):
215
- # Last and second to last frames
216
- if idx == len (images )- 2 and num_frames == 3 :
217
- # print("second to last frame:", i_frame)
218
- if i_frame == 0 :
219
- LR_img = images [idx ] if is_video else cv2 .imread (images [idx_frame ], cv2 .IMREAD_COLOR )
220
- else :
221
- LR_img = images [idx + 1 ] if is_video else cv2 .imread (images [idx_frame + 1 ], cv2 .IMREAD_COLOR )
222
- elif idx == len (images )- 1 and num_frames == 3 :
223
- # print("last frame:", i_frame)
224
- LR_img = images [idx ] if is_video else cv2 .imread (images [idx_frame ], cv2 .IMREAD_COLOR )
225
- # Every other internal frame
226
- else :
227
- # print("normal frame:", idx_frame)
228
- LR_img = images [idx + i_frame ] if is_video else cv2 .imread (images [idx_frame + i_frame ], cv2 .IMREAD_COLOR )
229
-
218
+ img_name = os .path .splitext (os .path .basename (images [idx ]))[0 ]
219
+ print (idx - num_padding , img_name )
220
+
221
+ # First pass
222
+ if idx == num_padding :
223
+ LR_list = []
224
+ LR_bicubic = None
225
+ # Load all beginning images on either side of current index
226
+ # E.g. num_frames = 7, from -3 to 3
227
+ for i in range (- num_padding , num_padding + 1 ):
228
+ # Read image or select video frame
229
+ LR_img = images [idx + i ] if is_video else cv2 .imread (
230
+ images [idx + i ], cv2 .IMREAD_COLOR )
231
+ if not only_y :
232
+ # TODO: Figure out why this is necessary
233
+ LR_img = cv2 .cvtColor (LR_img , cv2 .COLOR_BGR2RGB )
234
+ LR_list .append (LR_img )
235
+ # Other passes
236
+ else :
237
+ # Remove beginning frame from cached list
238
+ LR_list = previous_lr_list [1 :]
239
+ # Load next image or video frame
240
+ new_img = images [idx + num_padding ] if is_video else cv2 .imread (
241
+ images [idx + num_padding ], cv2 .IMREAD_COLOR )
230
242
if not only_y :
231
- LR_img = cv2 .cvtColor (LR_img , cv2 .COLOR_BGR2RGB )
232
-
233
- # get the bicubic upscale of the center frame to concatenate for SR
234
- if only_y and i_frame == idx_center :
235
- if args .denoise :
236
- LR_bicubic = cv2 .blur (LR_img , (3 ,3 ))
237
- else :
238
- LR_bicubic = LR_img
239
- LR_bicubic = util .imresize_np (img = LR_bicubic , scale = scale ) # bicubic upscale
240
-
241
- if only_y :
242
- # extract Y channel from frames
243
- # normal path, only Y for both
244
- LR_img = util .bgr2ycbcr (LR_img , only_y = True )
245
-
246
- # expand Y images to add the channel dimension
247
- # normal path, only Y for both
248
- LR_img = util .fix_img_channels (LR_img , 1 )
243
+ # TODO: Figure out why this is necessary
244
+ new_img = cv2 .cvtColor (LR_img , cv2 .COLOR_BGR2RGB )
245
+ LR_list .append (new_img )
246
+ # Cache current list for next iter
247
+ previous_lr_list = LR_list
249
248
250
- LR_list .append (LR_img ) # h, w, c
249
+ # Convert LR_list to grayscale
250
+ if only_y :
251
+ gray_lr_list = []
252
+ LR_bicubic = LR_list [num_padding ]
253
+ for i in range (len (LR_list )):
254
+ gray_lr = util .bgr2ycbcr (LR_list [i ], only_y = True )
255
+ gray_lr = util .fix_img_channels (gray_lr , 1 )
256
+ gray_lr_list .append (gray_lr )
257
+ LR_list = gray_lr_list
258
+
259
+ # Get the bicubic upscale of the center frame to concatenate for SR
260
+ if only_y :
261
+ if args .denoise :
262
+ LR_bicubic = cv2 .blur (LR_bicubic , (3 , 3 ))
263
+ else :
264
+ LR_bicubic = LR_bicubic
265
+ LR_bicubic = util .imresize_np (
266
+ img = LR_bicubic , scale = scale ) # bicubic upscale
251
267
252
- if not only_y :
253
- h_LR , w_LR , c = LR_img .shape
268
+ if not only_y :
269
+ h_LR , w_LR , c = LR_list [ 0 ] .shape
254
270
255
271
if not only_y :
256
272
t = num_frames
257
- LR = [np .asarray (LT ) for LT in LR_list ] # list -> numpy # input: list (contatin numpy: [H,W,C])
258
- LR = np .asarray (LR ) # numpy, [T,H,W,C]
259
- LR = LR .transpose (1 ,2 ,3 ,0 ).reshape (h_LR , w_LR , - 1 ) # numpy, [Hl',Wl',CT]
273
+ # list -> numpy # input: list (contatin numpy: [H,W,C])
274
+ LR = [np .asarray (LT ) for LT in LR_list ]
275
+ LR = np .asarray (LR ) # numpy, [T,H,W,C]
276
+ LR = LR .transpose (1 , 2 , 3 , 0 ).reshape (
277
+ h_LR , w_LR , - 1 ) # numpy, [Hl',Wl',CT]
260
278
else :
261
- LR = np .concatenate ((LR_list ), axis = 2 ) # h, w, t
279
+ LR = np .concatenate ((LR_list ), axis = 2 ) # h, w, t
262
280
263
281
if only_y :
264
- LR = util .np2tensor (LR , bgr2rgb = False , add_batch = True ) # Tensor, [CT',H',W'] or [T, H, W]
282
+ # Tensor, [CT',H',W'] or [T, H, W]
283
+ LR = util .np2tensor (LR , bgr2rgb = False , add_batch = True )
265
284
else :
266
- LR = util .np2tensor (LR , bgr2rgb = True , add_batch = False ) # Tensor, [CT',H',W'] or [T, H, W]
267
- LR = LR .view (c ,t ,h_LR ,w_LR ) # Tensor, [C,T,H,W]
268
- LR = LR .transpose (0 ,1 ) # Tensor, [T,C,H,W]
285
+ # Tensor, [CT',H',W'] or [T, H, W]
286
+ LR = util .np2tensor (LR , bgr2rgb = True , add_batch = False )
287
+ LR = LR .view (c , t , h_LR , w_LR ) # Tensor, [C,T,H,W]
288
+ LR = LR .transpose (0 , 1 ) # Tensor, [T,C,H,W]
269
289
LR = LR .unsqueeze (0 )
270
290
271
291
if only_y :
272
292
# generate Cr, Cb channels using bicubic interpolation
273
293
LR_bicubic = util .bgr2ycbcr (LR_bicubic , only_y = False )
274
- LR_bicubic = util .np2tensor (LR_bicubic , bgr2rgb = False , add_batch = True )
294
+ LR_bicubic = util .np2tensor (
295
+ LR_bicubic , bgr2rgb = False , add_batch = True )
275
296
else :
276
297
LR_bicubic = []
277
298
278
299
if len (LR .size ()) == 4 :
279
300
b , n_frames , h_lr , w_lr = LR .size ()
280
- LR = LR .view (b , - 1 , 1 , h_lr , w_lr ) # b, t, c, h, w
281
- elif len (LR .size ()) == 5 : # for networks that work with 3 channel images
301
+ LR = LR .view (b , - 1 , 1 , h_lr , w_lr ) # b, t, c, h, w
302
+ elif len (LR .size ()) == 5 : # for networks that work with 3 channel images
282
303
_ , n_frames , _ , _ , _ = LR .size ()
283
- LR = LR # b, t, c, h, w
284
-
285
-
304
+ LR = LR # b, t, c, h, w
286
305
287
306
if args .chop_forward :
288
-
289
307
# crop borders to ensure each patch can be divisible by 2
290
308
_ , _ , _ , h , w = LR .size ()
291
309
h = int (h // 16 ) * 16
@@ -294,7 +312,7 @@ def main():
294
312
if isinstance (LR_bicubic , torch .Tensor ):
295
313
SR_cb = LR_bicubic [:, 1 , :h * scale , :w * scale ]
296
314
SR_cr = LR_bicubic [:, 2 , :h * scale , :w * scale ]
297
-
315
+
298
316
SR_y = chop_forward (LR , model , scale ).squeeze (0 )
299
317
if only_y :
300
318
sr_img = ycbcr_to_rgb (torch .stack ((SR_y , SR_cb , SR_cr ), - 3 ))
@@ -309,23 +327,24 @@ def main():
309
327
SR = fake_H .detach ()[0 ].float ().cpu ()
310
328
if only_y :
311
329
SR_cb = LR_bicubic [:, 1 , :, :]
312
- SR_cr = LR_bicubic [:, 2 , :, :]
330
+ SR_cr = LR_bicubic [:, 2 , :, :]
313
331
sr_img = ycbcr_to_rgb (torch .stack ((SR , SR_cb , SR_cr ), - 3 ))
314
332
else :
315
333
sr_img = SR
316
-
334
+
317
335
sr_img = util .tensor2np (sr_img ) # uint8
318
336
319
337
if not is_video :
320
338
# save images
321
- cv2 .imwrite (os .path .join (output_folder , os .path .basename (path )), sr_img )
339
+ cv2 .imwrite (os .path .join (output_folder ,
340
+ os .path .basename (images [idx ])), sr_img )
322
341
else :
323
342
# Write SR frame to output video stream
324
343
sr_img = cv2 .cvtColor (sr_img , cv2 .COLOR_BGR2RGB )
325
344
process .stdin .write (
326
345
sr_img
327
- .astype (np .uint8 )
328
- .tobytes ()
346
+ .astype (np .uint8 )
347
+ .tobytes ()
329
348
)
330
349
331
350
# Close output stream
0 commit comments