 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import asyncio
 import os
 import warnings
+
+from abc import abstractmethod
 from threading import Thread

 import numpy as np
 import torch
 from PIL import Image
 from tqdm import tqdm
+from types import LambdaType


 def get_sdpa_settings():
@@ -89,39 +93,35 @@ def mask_to_box(masks: torch.Tensor):
     return bbox_coords


-def _load_img_as_tensor(img_path, image_size):
-    img_pil = Image.open(img_path)
+def _load_img_pil_as_tensor(img_id, img_pil, image_size):
     img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
     if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images
         img_np = img_np / 255.0
     else:
-        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_id}")
     img = torch.from_numpy(img_np).permute(2, 0, 1)
     video_width, video_height = img_pil.size  # the original video size
     return img, video_height, video_width


-class AsyncVideoFrameLoader:
+class LazyVideoFrameLoader:
     """
-    A list of video frames to be load asynchronously without blocking session start.
+    Abstract class that defines primitives to load frames lazily.
     """
-
     def __init__(
         self,
-        img_paths,
         image_size,
         offload_video_to_cpu,
         img_mean,
         img_std,
         compute_device,
     ):
-        self.img_paths = img_paths
         self.image_size = image_size
         self.offload_video_to_cpu = offload_video_to_cpu
         self.img_mean = img_mean
         self.img_std = img_std
         # items in `self.images` will be loaded asynchronously
-        self.images = [None] * len(img_paths)
+        self.images = [None] * self.__len__()
         # catch and raise any exceptions in the async loading thread
         self.exception = None
         # video_height and video_width be filled when loading the first image
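
The refactored helper at the top of this hunk now takes an already-decoded PIL image plus an identifier that is only used in error messages, so callers decide where the pixels come from. A minimal call sketch (the path and identifier are assumptions, not part of the diff):

    from PIL import Image

    # Returns a (3, image_size, image_size) tensor with values scaled to [0, 1],
    # plus the original frame height and width.
    img, video_height, video_width = _load_img_pil_as_tensor(
        img_id="frame 0",                               # only used in the error message
        img_pil=Image.open("/tmp/frames/00000.jpg"),    # hypothetical JPEG frame
        image_size=1024,
    )
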
@@ -131,18 +131,25 @@ def __init__(

         # load the first frame to fill video_height and video_width and also
         # to cache it (since it's most likely where the user will click)
-        self.__getitem__(0)
+        self.__getitem__(self.get_first_frame_num())

-        # load the rest of frames asynchronously without blocking the session start
-        def _load_frames():
+        if self.should_preload():
+            self.thread = Thread(
+                target=self.load_frames,
+                daemon=True,
+            )
+            self.thread.start()
+
+    def load_frames(self):
+        asyncio.run(self.preload())
+
+    async def preload(self):
+        async for index in self.get_preload_generator():
             try:
-                for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
-                    self.__getitem__(n)
+                self.__getitem__(index)
             except Exception as e:
-                self.exception = e
-
-        self.thread = Thread(target=_load_frames, daemon=True)
-        self.thread.start()
+                if self.propagate_preload_errors():
+                    self.exception = e

     def __getitem__(self, index):
         if self.exception is not None:
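
With this change, preloading is driven by whatever async generator `get_preload_generator()` returns: `load_frames()` runs `asyncio.run(self.preload())` on a daemon thread and fetches every index the generator yields. A sketch of one possible generator for a source that produces frames over time (the helper name and the `num_ready` callable are assumptions, not part of the diff):

    import asyncio

    async def indices_as_available(num_ready, max_frames, poll_interval=0.05):
        # num_ready() is assumed to report how many frames the producer has
        # written so far; yield each index once it becomes available.
        for index in range(max_frames):
            while index >= num_ready():
                await asyncio.sleep(poll_interval)
            yield index
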
@@ -152,8 +159,8 @@ def __getitem__(self, index):
         if img is not None:
             return img

-        img, video_height, video_width = _load_img_as_tensor(
-            self.img_paths[index], self.image_size
+        img, video_height, video_width = _load_img_pil_as_tensor(
+            self.get_image_id(index), self.load_image(index), self.image_size
         )
         self.video_height = video_height
         self.video_width = video_width
@@ -166,16 +173,132 @@ def __getitem__(self, index):
         return img

     def __len__(self):
+        return self.get_length()
+
+    @abstractmethod
+    def get_first_frame_num(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def should_preload(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_preload_generator(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def propagate_preload_errors(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def load_image(self, index):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_image_id(self, index):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_length(self):
+        raise NotImplementedError
+
+
+class AsyncVideoFrameLoader(LazyVideoFrameLoader):
+    """
+    A list of video frames to be loaded asynchronously without blocking session start.
+    """
+
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
+        self.img_paths = img_paths
+        LazyVideoFrameLoader.__init__(
+            self, image_size, offload_video_to_cpu, img_mean, img_std, compute_device
+        )
+
+    def get_first_frame_num(self):
+        return 0
+
+    def should_preload(self):
+        return True
+
+    def get_preload_generator(self):
+        async def _available(img_paths):
+            for i in tqdm(range(len(img_paths)), desc="frame loading (JPEG)"):
+                yield i
+
+        return _available(self.img_paths)
+
+    def propagate_preload_errors(self):
+        return True
+
+    def load_image(self, index):
+        return Image.open(self.img_paths[index])
+
+    def get_image_id(self, index):
+        return self.img_paths[index]
+
+    def get_length(self):
-        return len(self.images)
+        return len(self.img_paths)


+
+class StreamingVideoFrameLoader(LazyVideoFrameLoader):
+    """
+    A list of video frames that can be loaded lazily even if they are produced after session start.
+    """
+    def __init__(
+        self,
+        loader_func,
+        stream_config,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
+        self.loader_func = loader_func
+        self.stream_config = stream_config
+        LazyVideoFrameLoader.__init__(
+            self, image_size, offload_video_to_cpu, img_mean, img_std, compute_device
+        )
+
+    def get_first_frame_num(self):
+        return self.stream_config.get("first_frame_num", 0)
+
+    def should_preload(self):
+        return self.stream_config.get("preload_gen", None) is not None
+
+    def get_preload_generator(self):
+        return self.stream_config.get("preload_gen")
+
+    def propagate_preload_errors(self):
+        return self.stream_config.get("propagate_preload_errors", True)
+
+    def load_image(self, index):
+        return self.loader_func(index)
+
+    def get_image_id(self, index):
+        return str(index)
+
+    def get_length(self):
+        return self.stream_config.get("max_frames")
+
+
 def load_video_frames(
     video_path,
     image_size,
     offload_video_to_cpu,
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
-    async_loading_frames=False,
+    frame_load_config=None,
     compute_device=torch.device("cuda"),
 ):
     """
@@ -184,6 +307,7 @@ def load_video_frames(
     """
     is_bytes = isinstance(video_path, bytes)
     is_str = isinstance(video_path, str)
+    is_func = isinstance(video_path, LambdaType)
     is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
     if is_bytes or is_mp4_path:
         return load_video_frames_from_video_file(
@@ -201,7 +325,18 @@ def load_video_frames(
             offload_video_to_cpu=offload_video_to_cpu,
             img_mean=img_mean,
             img_std=img_std,
-            async_loading_frames=async_loading_frames,
+            frame_load_config=frame_load_config,
+            compute_device=compute_device,
+        )
+
+    elif is_func:
+        return load_video_frames_from_lambda(
+            loader_func=video_path,
+            image_size=image_size,
+            offload_video_to_cpu=offload_video_to_cpu,
+            img_mean=img_mean,
+            img_std=img_std,
+            frame_load_config=frame_load_config,
             compute_device=compute_device,
         )
     else:
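
With the dispatch above, `video_path` may now be a lambda mapping a frame index to a PIL image. A hedged usage sketch (the frame naming scheme is an assumption, not part of the diff):

    import torch
    from PIL import Image

    frames, video_height, video_width = load_video_frames(
        video_path=lambda i: Image.open(f"/tmp/stream/{i:05d}.jpg"),  # hypothetical frames
        image_size=1024,
        offload_video_to_cpu=True,
        frame_load_config={"max_frames": 300},   # forwarded as the loader's stream_config
        compute_device=torch.device("cpu"),
    )

Note that `types.LambdaType` is the same type object as `types.FunctionType`, so a plain `def` function also matches this branch, while a `functools.partial` or a bound method would not.
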
@@ -210,13 +345,37 @@ def load_video_frames(
         )


+def load_video_frames_from_lambda(
+    loader_func,
+    image_size,
+    offload_video_to_cpu,
+    img_mean=(0.485, 0.456, 0.406),
+    img_std=(0.229, 0.224, 0.225),
+    frame_load_config=None,
+    compute_device=torch.device("cuda"),
+):
+    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+    lazy_images = StreamingVideoFrameLoader(
+        loader_func,
+        frame_load_config,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    )
+    return lazy_images, lazy_images.video_height, lazy_images.video_width
+
+
 def load_video_frames_from_jpg_images(
     video_path,
     image_size,
     offload_video_to_cpu,
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
-    async_loading_frames=False,
+    frame_load_config=None,
     compute_device=torch.device("cuda"),
 ):
     """
@@ -253,7 +412,7 @@ def load_video_frames_from_jpg_images(
     img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
     img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

-    if async_loading_frames:
+    if frame_load_config is not None and frame_load_config.get("async", False):
         lazy_images = AsyncVideoFrameLoader(
             img_paths,
             image_size,
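
The former `async_loading_frames=True` flag is now requested through the config dict; a usage sketch for the JPEG-folder path (the folder name is hypothetical, with frames named `<frame_index>.jpg` as the JPEG loader expects):

    import torch

    images, video_height, video_width = load_video_frames(
        video_path="/tmp/video_frames",        # hypothetical folder of <frame_index>.jpg files
        image_size=1024,
        offload_video_to_cpu=True,
        frame_load_config={"async": True},     # replaces async_loading_frames=True
        compute_device=torch.device("cpu"),
    )
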
@@ -266,7 +425,9 @@ def load_video_frames_from_jpg_images(

     images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
-        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+        images[n], video_height, video_width = _load_img_pil_as_tensor(
+            img_path, Image.open(img_path), image_size
+        )
     if not offload_video_to_cpu:
         images = images.to(compute_device)
         img_mean = img_mean.to(compute_device)