Skip to content

Commit 1854f5f

Browse files
committed
Remove duplicate imports, return tensor
1 parent 6ad479c commit 1854f5f

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,19 +150,19 @@ class OpenCVDecoder(AbstractDecoder):
150150
def __init__(self, backend):
151151
import cv2
152152

153+
self.cv2 = cv2
154+
153155
self._available_backends = {"FFMPEG": cv2.CAP_FFMPEG}
154156
self._backend = self._available_backends.get(backend)
155157

156158
self._print_each_iteration_time = False
157159

158160
def decode_frames(self, video_file, pts_list):
159-
import cv2
160-
161-
cap = cv2.VideoCapture(video_file, self._backend)
161+
cap = self.cv2.VideoCapture(video_file, self._backend)
162162
if not cap.isOpened():
163163
raise ValueError("Could not open video stream")
164164

165-
fps = cap.get(cv2.CAP_PROP_FPS)
165+
fps = cap.get(self.cv2.CAP_PROP_FPS)
166166
approx_frame_indices = [int(pts * fps) for pts in pts_list]
167167

168168
current_frame = 0
@@ -174,6 +174,11 @@ def decode_frames(self, video_file, pts_list):
174174
if current_frame in approx_frame_indices: # only decompress needed
175175
ret, frame = cap.retrieve()
176176
if ret:
177+
# OpenCV uses BGR, change to RGB
178+
frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
179+
# Update to C, H, W
180+
frame = np.transpose(frame, (2, 0, 1))
181+
frame = torch.from_numpy(frame)
177182
frames.append(frame)
178183

179184
if len(frames) == len(approx_frame_indices):
@@ -184,9 +189,7 @@ def decode_frames(self, video_file, pts_list):
184189
return frames
185190

186191
def decode_first_n_frames(self, video_file, n):
187-
import cv2
188-
189-
cap = cv2.VideoCapture(video_file, self._backend)
192+
cap = self.cv2.VideoCapture(video_file, self._backend)
190193
if not cap.isOpened():
191194
raise ValueError("Could not open video stream")
192195

@@ -197,16 +200,21 @@ def decode_first_n_frames(self, video_file, n):
197200
raise ValueError("Could not grab video frame")
198201
ret, frame = cap.retrieve()
199202
if ret:
203+
# OpenCV uses BGR, change to RGB
204+
frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
205+
# Update to C, H, W
206+
frame = np.transpose(frame, (2, 0, 1))
207+
frame = torch.from_numpy(frame)
200208
frames.append(frame)
201209
cap.release()
202210
assert len(frames) == n
203211
return frames
204212

205213
def decode_and_resize(self, video_file, pts_list, height, width, device):
206-
import cv2
207214

215+
# OpenCV doesn't apply antialias, while other `decode_and_resize()` implementations apply antialias by default.
208216
frames = [
209-
cv2.resize(frame, (width, height))
217+
self.cv2.resize(frame, (width, height))
210218
for frame in self.decode_frames(video_file, pts_list)
211219
]
212220
return frames

0 commit comments

Comments (0)