
Commit fdee8d6

[sam2] Add streaming frame loading on top of lazy loading
1 parent c98aa6b commit fdee8d6

3 files changed: +191 -29 lines

sam2/sam2_video_predictor.py

Lines changed: 3 additions & 2 deletions

@@ -46,15 +46,16 @@ def init_state(
         video_path,
         offload_video_to_cpu=False,
         offload_state_to_cpu=False,
-        async_loading_frames=False,
+        frame_load_config=None,
     ):
         """Initialize an inference state."""
+        frame_load_config = frame_load_config or {}
         compute_device = self.device  # device of the model
         images, video_height, video_width = load_video_frames(
             video_path=video_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
-            async_loading_frames=async_loading_frames,
+            frame_load_config=frame_load_config,
             compute_device=compute_device,
         )
         inference_state = {}
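
With this change, `init_state` takes a `frame_load_config` dict in place of the old `async_loading_frames` flag; for a directory of JPEG frames, background loading appears to be requested through an "async" key (see `load_video_frames_from_jpg_images` below). A minimal caller-side sketch, with the model builder call, config name, and checkpoint path assumed rather than taken from this commit:

# Hypothetical usage sketch; paths and config names below are assumptions.
import torch
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_l.yaml",    # assumed config name
    "./checkpoints/sam2.1_hiera_large.pt",   # assumed checkpoint path
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Before this commit: init_state(video_path=..., async_loading_frames=True)
# After this commit, the same request goes through frame_load_config:
inference_state = predictor.init_state(
    video_path="./videos/bedroom",           # directory of JPEG frames
    frame_load_config={"async": True},       # lazy, background JPEG loading
)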

sam2/utils/misc.py

Lines changed: 186 additions & 25 deletions
@@ -4,14 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import os
 import warnings
+
+from abc import abstractmethod
 from threading import Thread
 
 import numpy as np
 import torch
 from PIL import Image
 from tqdm import tqdm
+from types import LambdaType
 
 
 def get_sdpa_settings():
@@ -89,39 +93,35 @@ def mask_to_box(masks: torch.Tensor):
     return bbox_coords
 
 
-def _load_img_as_tensor(img_path, image_size):
-    img_pil = Image.open(img_path)
+def _load_img_pil_as_tensor(img_id, img_pil, image_size):
     img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
     if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images
         img_np = img_np / 255.0
     else:
-        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_id}")
     img = torch.from_numpy(img_np).permute(2, 0, 1)
     video_width, video_height = img_pil.size  # the original video size
     return img, video_height, video_width
 
 
-class AsyncVideoFrameLoader:
+class LazyVideoFrameLoader:
     """
-    A list of video frames to be load asynchronously without blocking session start.
+    Abstract class that defines primitives to load frames lazily.
     """
-
     def __init__(
         self,
-        img_paths,
         image_size,
         offload_video_to_cpu,
         img_mean,
         img_std,
         compute_device,
     ):
-        self.img_paths = img_paths
         self.image_size = image_size
         self.offload_video_to_cpu = offload_video_to_cpu
         self.img_mean = img_mean
         self.img_std = img_std
         # items in `self.images` will be loaded asynchronously
-        self.images = [None] * len(img_paths)
+        self.images = [None] * self.__len__()
         # catch and raise any exceptions in the async loading thread
         self.exception = None
         # video_height and video_width be filled when loading the first image
@@ -131,18 +131,25 @@ def __init__(
 
         # load the first frame to fill video_height and video_width and also
         # to cache it (since it's most likely where the user will click)
-        self.__getitem__(0)
+        self.__getitem__(self.get_first_frame_num())
 
-        # load the rest of frames asynchronously without blocking the session start
-        def _load_frames():
+        if self.should_preload():
+            self.thread = Thread(
+                target=self.load_frames,
+                daemon=True,
+            )
+            self.thread.start()
+
+    def load_frames(self):
+        asyncio.run(self.preload())
+
+    async def preload(self):
+        async for index in self.get_preload_generator():
             try:
-                for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
-                    self.__getitem__(n)
+                self.__getitem__(index)
             except Exception as e:
-                self.exception = e
-
-        self.thread = Thread(target=_load_frames, daemon=True)
-        self.thread.start()
+                if self.propagate_preload_errors():
+                    self.exception = e
 
     def __getitem__(self, index):
         if self.exception is not None:
@@ -152,8 +159,8 @@ def __getitem__(self, index):
         if img is not None:
             return img
 
-        img, video_height, video_width = _load_img_as_tensor(
-            self.img_paths[index], self.image_size
+        img, video_height, video_width = _load_img_pil_as_tensor(
+            self.get_image_id(index), self.load_image(index), self.image_size
         )
         self.video_height = video_height
         self.video_width = video_width
@@ -166,16 +173,132 @@ def __getitem__(self, index):
         return img
 
     def __len__(self):
+        return self.get_length()
+
+    @abstractmethod
+    def get_first_frame_num(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def should_preload(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_preload_generator(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def propagate_preload_errors(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def load_image(self, index):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_image_id(self, index):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_length(self):
+        raise NotImplementedError
+
+
+class AsyncVideoFrameLoader(LazyVideoFrameLoader):
+    """
+    A list of video frames to be load asynchronously without blocking session start.
+    """
+
+    def __init__(
+        self,
+        img_paths,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
+        self.img_paths = img_paths
+        LazyVideoFrameLoader.__init__(
+            self, image_size, offload_video_to_cpu, img_mean, img_std, compute_device
+        )
+
+    def get_first_frame_num(self):
+        return 0
+
+    def should_preload(self):
+        return True
+
+    def get_preload_generator(self):
+        async def _available(img_paths):
+            for i in tqdm(len(img_paths), desc="frame loading (JPEG)"):
+                yield i
+
+        return _available(self.img_paths)
+
+    def propagate_preload_errors(self):
+        return True
+
+    def load_image(self, index):
+        return Image.load(self.img_paths[index])
+
+    def get_image_id(self, index):
+        return self.img_paths[index]
+
+    def get_length(self):
         return len(self.images)
 
 
+
+class StreamingVideoFrameLoader(LazyVideoFrameLoader):
+    """
+    A list of video frames that can be loaded lazily even if they are produced after session start.
+    """
+    def __init__(
+        self,
+        loader_func,
+        stream_config,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    ):
+        self.loader_func = loader_func
+        self.stream_config = stream_config
+        LazyVideoFrameLoader.__init__(
+            self, image_size, offload_video_to_cpu, img_mean, img_std, compute_device
+        )
+
+    def get_first_frame_num(self):
+        return self.stream_config.get("first_frame_num", 0)
+
+    def should_preload(self):
+        return self.stream_config.get("preload_gen", None) is not None
+
+    def get_preload_generator(self):
+        return self.stream_config.get("preload_gen")
+
+    def propagate_preload_errors(self):
+        return self.stream_config.get("propagate_preload_errors", True)
+
+    def load_image(self, index):
+        return self.loader_func(index)
+
+    def get_image_id(self, index):
+        return str(index)
+
+    def get_length(self):
+        return self.stream_config.get("max_frames")
+
+
 def load_video_frames(
     video_path,
     image_size,
     offload_video_to_cpu,
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
-    async_loading_frames=False,
+    frame_load_config=None,
     compute_device=torch.device("cuda"),
 ):
     """
@@ -184,6 +307,7 @@ def load_video_frames(
     """
     is_bytes = isinstance(video_path, bytes)
     is_str = isinstance(video_path, str)
+    is_func = isinstance(video_path, LambdaType)
     is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
     if is_bytes or is_mp4_path:
         return load_video_frames_from_video_file(
@@ -201,7 +325,18 @@ def load_video_frames(
             offload_video_to_cpu=offload_video_to_cpu,
             img_mean=img_mean,
             img_std=img_std,
-            async_loading_frames=async_loading_frames,
+            frame_load_config=frame_load_config,
+            compute_device=compute_device,
+        )
+
+    elif is_func:
+        return load_video_frames_from_lambda(
+            loader_func=video_path,
+            image_size=image_size,
+            offload_video_to_cpu=offload_video_to_cpu,
+            img_mean=img_mean,
+            img_std=img_std,
+            frame_load_config=frame_load_config,
             compute_device=compute_device,
         )
     else:
@@ -210,13 +345,37 @@ def load_video_frames(
         )
 
 
+def load_video_frames_from_lambda(
+    loader_func,
+    image_size,
+    offload_video_to_cpu,
+    img_mean=(0.485, 0.456, 0.406),
+    img_std=(0.229, 0.224, 0.225),
+    frame_load_config=None,
+    compute_device=torch.device("cuda"),
+):
+    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+    lazy_images = StreamingVideoFrameLoader(
+        loader_func,
+        frame_load_config,
+        image_size,
+        offload_video_to_cpu,
+        img_mean,
+        img_std,
+        compute_device,
+    )
+    return lazy_images, lazy_images.video_height, lazy_images.video_width
+
+
 def load_video_frames_from_jpg_images(
     video_path,
     image_size,
     offload_video_to_cpu,
     img_mean=(0.485, 0.456, 0.406),
     img_std=(0.229, 0.224, 0.225),
-    async_loading_frames=False,
+    frame_load_config=None,
     compute_device=torch.device("cuda"),
 ):
     """
@@ -253,7 +412,7 @@ def load_video_frames_from_jpg_images(
     img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
     img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
 
-    if async_loading_frames:
+    if frame_load_config.get("async", False):
         lazy_images = AsyncVideoFrameLoader(
             img_paths,
             image_size,
@@ -266,7 +425,9 @@ def load_video_frames_from_jpg_images(
 
     images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
     for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
-        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+        images[n], video_height, video_width = _load_img_pil_as_tensor(
+            img_path, Image.open(img_path), image_size
+        )
     if not offload_video_to_cpu:
         images = images.to(compute_device)
         img_mean = img_mean.to(compute_device)
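
Putting the new pieces together: `load_video_frames` now also dispatches on `isinstance(video_path, LambdaType)` (in Python this is the same type object as a plain function), routing a callable frame source to `load_video_frames_from_lambda` and a `StreamingVideoFrameLoader`. A hedged sketch of a streaming session follows; the config keys are the ones read by `StreamingVideoFrameLoader`, the frame source is a hypothetical stand-in, and it must return a PIL image because its result is passed to `_load_img_pil_as_tensor`:

# Streaming-session sketch, reusing a `predictor` built as in the earlier example.
from PIL import Image

def read_frame_from_stream(i):
    # Hypothetical helper: fetch/decode frame `i` from a camera, socket, ring
    # buffer, etc., and return it as a PIL.Image.
    return Image.open(f"/tmp/stream/{i:06d}.jpg")

inference_state = predictor.init_state(
    video_path=lambda i: read_frame_from_stream(i),   # callable source -> streaming path
    frame_load_config={
        "max_frames": 1000,                # used as the loader's length (get_length)
        "first_frame_num": 0,              # frame loaded eagerly at session start
        "preload_gen": None,               # optional async generator of indices to prefetch
        "propagate_preload_errors": True,  # surface background-loading exceptions
    },
)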

tools/vos_inference.py

Lines changed: 2 additions & 2 deletions
@@ -135,7 +135,7 @@ def vos_inference(
     ]
     frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
     inference_state = predictor.init_state(
-        video_path=video_dir, async_loading_frames=False
+        video_path=video_dir, frame_load_config=None
     )
     height = inference_state["video_height"]
     width = inference_state["video_width"]
@@ -273,7 +273,7 @@ def vos_separate_inference_per_object(
     ]
     frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
     inference_state = predictor.init_state(
-        video_path=video_dir, async_loading_frames=False
+        video_path=video_dir, frame_load_config=None
     )
     height = inference_state["video_height"]
     width = inference_state["video_width"]
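
Finally, when a `preload_gen` is supplied, `LazyVideoFrameLoader` starts a daemon thread that runs `asyncio.run(self.preload())` and caches every index the generator yields. A sketch of such a generator, with the pacing value invented purely for illustration:

# Hypothetical prefetch schedule: yield each index at roughly capture rate so the
# background thread caches frames as they become available. It would be supplied as
# frame_load_config={"preload_gen": paced_indices(1000), "max_frames": 1000, ...}.
import asyncio

async def paced_indices(max_frames, interval_s=0.033):
    for i in range(max_frames):
        yield i                        # preload() calls self.__getitem__(i) for each yield
        await asyncio.sleep(interval_s)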
