
Conversation

@Guppy16 commented Jan 13, 2026

What does this PR do?

Fix the SAM2 Video inference processor so that it supports float16 (it currently only works for fp32 and bfloat16).

How to reproduce

Demo source from here

  • This demo already works for dtype = torch.bfloat16 and dtype = torch.float32;
  • this PR fixes it for dtype = torch.float16.
  • (Note that fp8 / int8 / etc. are not supported.)
import torch
from transformers import Sam2VideoModel, Sam2VideoProcessor
from transformers.video_utils import load_video

device = torch.device("cuda")
dtype = torch.float16
model_name = "facebook/sam2.1-hiera-tiny"

model = Sam2VideoModel.from_pretrained(model_name).to(device, dtype=dtype)

processor = Sam2VideoProcessor.from_pretrained(model_name)


video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"

video_frames, _ = load_video(video_url)

# Initialize session for streaming
inference_session = processor.init_video_session(
    inference_device=device,
    dtype=dtype,
)

# Process frames one by one
for frame_idx, frame in enumerate(video_frames[:10]):  # Process first 10 frames
    inputs = processor(images=frame, device=device, return_tensors="pt")
    if frame_idx == 0:
        # Add point input on first frame
        processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=0,
            obj_ids=1,
            input_points=[[[[210, 350], [250, 220]]]],
            input_labels=[[[1, 1]]],
            original_size=inputs.original_sizes[0],  # needs to be provided when using streaming video inference
        )
    # Process current frame
    sam2_video_output = model(
        inference_session=inference_session, frame=inputs.pixel_values[0]
    )
    video_res_masks = processor.post_process_masks(
        [sam2_video_output.pred_masks],
        original_sizes=inputs.original_sizes,
        binarize=False,
    )[0]
    print(f"Frame {frame_idx}: mask shape {video_res_masks.shape}")

Who can review?

@yonigozlan @molbap

@github-actions commented

[For maintainers] Suggested jobs to run (before merge)

run-slow: sam2_video
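
(For reference, slow tests in transformers are gated behind the RUN_SLOW environment variable, so something along the lines of RUN_SLOW=1 python -m pytest tests/models/sam2_video/ should reproduce this job locally; the exact test path is an assumption.)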

@Guppy16 (author) commented Jan 16, 2026

@yonigozlan bump. A few CI/CD pipelines are failing; it seems to be related to code quality and consistency checks between SAM 2 and SAM 3, but I'm not entirely sure.
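
(For reference, transformers ships make quality and make repo-consistency targets that run the code-quality and cross-model consistency checks locally; whether those correspond to the failing pipelines here is an assumption.)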
