@Aravind-11 (Contributor) commented Nov 10, 2025:

What does this PR do?

Implements SDPA for OWL-ViT.

Fixes #28103

Who can review?

@vasqu @younesbelkada

@Aravind-11 (Contributor Author) commented:

Revamp of #28818.

I ran `RUN_SLOW=1 python -m pytest tests/models/owlvit/test_modeling_owlvit.py` against the original OWL-ViT implementation, and it failed the same tests as my current implementation, so I'm not sure what to infer from that.

@vasqu (Contributor) left a comment:

Sorry but I've got to be strict about this. We no longer implement separate classes for all the attention flavors but one unified one. I think ViT is a good example in this case, e.g. see https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py

Before changing this to these standards I won't take a proper look for now.
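For reference, the unified pattern looks roughly like the sketch below (minimal and illustrative, based on modeling_vit.py; the class body and kwargs are assumptions here, not this PR's final code, and `eager_attention_forward` is the shared helper discussed further down in this thread):

```python
import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class OwlViTAttention(nn.Module):
    """One attention class; the backend is chosen by config, not by subclassing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        bsz, seq_len, _ = hidden_states.shape
        # (bsz, seq_len, hidden) -> (bsz, num_heads, seq_len, head_dim)
        query = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Single dispatch point: "eager", "sdpa", "flash_attention_2", "flex_attention", ...
        attention_interface = eager_attention_forward  # shared helper, see below
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self, query, key, value, attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scale,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, seq_len, -1).contiguous()
        return self.out_proj(attn_output), attn_weights
```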

@Aravind-11 (Contributor Author) replied:

> Sorry but I've got to be strict about this. We no longer implement separate classes for all the attention flavors but one unified one. I think ViT is a good example in this case, e.g. see https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
>
> Before changing this to these standards I won't take a proper look for now.

Got it. Thanks a lot!

@Aravind-11 (Contributor Author) commented:

> Sorry but I've got to be strict about this. We no longer implement separate classes for all the attention flavors but one unified one. I think ViT is a good example in this case, e.g. see https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
>
> Before changing this to these standards I won't take a proper look for now.

I made similar changes as in ViT and removed the separate SDPA class. Let me know what you think!

@vasqu (Contributor) left a comment:

Added some comments, but in general it would be best to have a green CI before requesting a review. At the moment, things are likely not working as expected.

Comment on lines 716 to 724
```diff
-        causal_attention_mask = _create_4d_causal_attention_mask(
-            input_shape, hidden_states.dtype, device=hidden_states.device
-        )
+        # OWL-ViT uses a bidirectional (non-causal) encoder.
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=attention_mask,
+        )
-        # expand attention_mask
-        if attention_mask is not None:
-            # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
```
Contributor:

This seems to suffer from the same issue as in #41750

It does not use a bidirectional mask, but a causal mask:

- The first mask is a base causal mask
- The second is a padding mask
- These are added together, producing a causal mask with padding included
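Concretely, the combination works like this sketch (illustrative shapes only; the actual mask utilities in transformers differ in detail):

```python
import torch

# Sketch: combine a causal mask with a padding mask by addition, using the
# additive 4D mask convention (0 = attend, large negative = masked).
bsz, seq_len, dtype = 2, 5, torch.float32
min_val = torch.finfo(dtype).min

# Base causal mask: (1, 1, seq_len, seq_len); future positions get min_val.
causal = torch.full((seq_len, seq_len), min_val, dtype=dtype).triu(1)[None, None]

# Padding mask: 1 for real tokens, 0 for padding, expanded to (bsz, 1, 1, seq_len).
padding = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
padding_4d = (1.0 - padding[:, None, None, :].to(dtype)) * min_val

# Added on top of each other: causal structure plus padded columns masked out.
combined = (causal + padding_4d).clamp(min=min_val)
```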

Contributor:

This may also need to adjust the is_causal argument dynamically, as in the PR I linked, although I'm not sure if it's just causal in general.
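For context, the dynamic adjustment usually amounts to something like this sketch (the exact condition is model-dependent and an assumption here):

```python
import torch
import torch.nn.functional as F


def sdpa_with_dynamic_is_causal(query, key, value, attention_mask=None):
    # Only ask SDPA for the fused causal kernel when there is no explicit mask;
    # the q_len > 1 check avoids setting the flag for single-token queries.
    is_causal = attention_mask is None and query.shape[2] > 1
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )
```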

Contributor Author:

Thanks! I made some changes to the code after referring to CLIP: removing output_attentions, return_dict, and causal_attention_mask. I also copied the eager attention part and the attention reshaping from CLIP, and added flash and flex attention too.

I think the current CI is failing because the OWL-ViT config file conflicts with the current encoder implementation. Could you guide me here? Thanks a lot!

Contributor Author:

Hi, I investigated the failing `OwlViTForObjectDetectionTest::test_eager_matches_sdpa_inference_09_fp32_pad_left`.

The failure is due to the test invoking `OwlViTForObjectDetection.forward()` without providing `pixel_values`.

`OwlViTForObjectDetection` requires `pixel_values` (image tensors) for its vision backbone. When the test omits them, the model raises a `ValueError: 'pixel_values' is None`.
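The guard that raises is presumably something along these lines (a paraphrase, not the exact modeling_owlvit.py code):

```python
# Paraphrased sketch of the input validation; the exact message differs.
def forward(self, pixel_values=None, input_ids=None, **kwargs):
    if pixel_values is None:
        raise ValueError("You have to specify pixel_values")
    # ...the vision backbone consumes pixel_values from here on
```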

```python
    base_model_prefix = "owlvit"
    input_modalities = ["image", "text"]
    supports_gradient_checkpointing = True
    _supports_sdpa = True
```
Contributor:

It should also support flash attn and flex attention then.
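i.e., something along these lines (a sketch; the flag names follow current transformers, and older releases spelled the first one `_supports_flash_attn_2`):

```python
# Sketch: advertise all supported attention backends, not just SDPA.
base_model_prefix = "owlvit"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn = True  # assumed name; older releases: _supports_flash_attn_2
_supports_flex_attn = True
```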



```diff
-# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT
+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
```
Contributor:

This doesn't work; `output_attentions` would be ignored --> see the `check_model_inputs` decorator in other models like CLIP.
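The decorator pattern being referenced looks roughly like this (a sketch modeled on recent CLIP; import paths are from current transformers and may shift between versions):

```python
from torch import nn
from transformers.processing_utils import Unpack
from transformers.utils.generic import TransformersKwargs, check_model_inputs


class OwlViTEncoder(nn.Module):
    @check_model_inputs
    def forward(self, inputs_embeds, attention_mask=None, **kwargs: Unpack[TransformersKwargs]):
        # output_attentions / output_hidden_states travel in kwargs and are
        # collected by the decorator, so they are not silently ignored here.
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, **kwargs)
        return hidden_states
```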

Comment on lines 440 to 441
```python
        # Eager attention implementation
        # Scale query
```
Contributor:

Please follow what is done in other models; we should be able to copy the eager function from BERT (see CLIP as an example).
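For reference, the shared helper being pointed to looks like this (reproduced from the pattern in modeling_clip.py; verify against the current source before copying):

```python
import torch
from torch import nn


# The "copied from bert" eager attention helper, as it appears in models like CLIP.
def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```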

Comment on lines 419 to 424
```python
        # Prepare attention mask - OWL-ViT uses bidirectional (non-causal) attention
        attention_mask = create_bidirectional_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
        )
```
Contributor:

The mask is prepared outside, not within the attention module.
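i.e., the mask construction belongs one level up, roughly like this sketch (illustrative placement using the PR's own `create_bidirectional_mask` helper; the final location is up to the review):

```python
from torch import nn


class OwlViTTextTransformer(nn.Module):
    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.embeddings(input_ids)
        # Mask is built once here, outside the attention module, and passed down.
        attention_mask = create_bidirectional_mask(
            config=self.config,
            input_embeds=hidden_states,
            attention_mask=attention_mask,
        )
        return self.encoder(inputs_embeds=hidden_states, attention_mask=attention_mask)
```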

Comment on lines 409 to 417
```diff
         # Get query, key, value projections
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        # Reshape to (bsz, num_heads, seq_len, head_dim)
+        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
```
Contributor:

Please follow llama / clip; the -1 dim is also important, for example.
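The llama/clip-style reshape keeps -1 on the sequence dimension, along these lines (sketch):

```python
# Sketch: llama/clip-style reshape. The -1 on the sequence dimension keeps the
# view robust (e.g. under torch.compile) instead of hard-coding tgt_len.
hidden_shape = (bsz, -1, self.num_heads, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
```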

@github-actions commented:

[For maintainers] Suggested jobs to run (before merge)

run-slow: owlv2, owlvit


Successfully merging this pull request may close these issues:

OWL-VIT Vision Foundation Model deployment in the edge cases - Need SDPA support for OWL-ViT Model optimization for Edge Deployment
