Sdpa for owlvit #42136
Conversation
I ran `RUN_SLOW=1 python -m pytest tests/models/owlvit/test_modeling_owlvit.py` against the original owlvit implementation and it seemed to fail the same tests as my current implementation. I'm not sure what to infer from that.
vasqu
left a comment
Sorry, but I've got to be strict about this. We no longer implement separate classes for each attention flavor; there is one unified attention implementation instead. I think ViT is a good example in this case, e.g. see https://github.com/huggingface/transformers/blob/main/src/transformers/models/vit/modeling_vit.py
Until this is changed to match those standards, I won't take a proper look for now.
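For reference, a minimal sketch of the unified pattern from ViT/CLIP that this comment asks for: one attention class whose forward resolves the backend (eager / SDPA / flash / flex) from `config._attn_implementation`. The class name, kwargs, and the eager fallback below are illustrative assumptions modeled on those files, not the final OwlViT code.

```python
import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # plain softmax attention; serves as the fallback when _attn_implementation == "eager"
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask  # additive mask, large negative at masked positions
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class OwlViTAttentionSketch(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        batch, seq_len, _ = hidden_states.shape
        hidden_shape = (batch, seq_len, -1, self.head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # single code path: the backend is resolved here instead of in separate Sdpa/Flash classes
        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            is_causal=False,  # whether the text tower should be causal is discussed later in this thread
            scaling=self.scale,
            dropout=0.0 if not self.training else self.dropout,
            **kwargs,
        )
        attn_output = attn_output.reshape(batch, seq_len, -1).contiguous()
        return self.out_proj(attn_output), attn_weights
```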
Got it. Thanks a lot!
Force-pushed from d519ced to 77f5221
I made similar changes as in ViT and removed the separate SDPA class. Let me know what you think!
vasqu
left a comment
Added some comments, but in general it would be best to have a green CI before requesting a review. At the moment, things are likely not working as expected.
```python
causal_attention_mask = _create_4d_causal_attention_mask(
    input_shape, hidden_states.dtype, device=hidden_states.device
# OWL-ViT uses a bidirectional (non-causal) encoder.
attention_mask = create_bidirectional_mask(
    config=self.config,
    input_embeds=hidden_states,
    attention_mask=attention_mask,
)
# expand attention_mask
if attention_mask is not None:
    # [num_samples, seq_len] -> [num_samples, 1, tgt_seq_len, src_seq_len]
    attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
```
This seems to suffer from the same issue as in #41750
It does not use a bidirectional mask, but a causal mask:
- The first mask is a base causal mask
- The second is a padding mask
- These are added together, creating a causal mask with padding included (see the sketch below)
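To illustrate the point, here is a minimal sketch (with assumed shapes and a hypothetical helper name, not the actual transformers internals) of how an additive causal mask and a padding mask combine into one mask:

```python
import torch


def combined_causal_padding_mask(attention_mask: torch.Tensor, dtype=torch.float32):
    # attention_mask: [batch, seq_len] with 1 for real tokens and 0 for padding
    batch, seq_len = attention_mask.shape
    min_val = torch.finfo(dtype).min

    # causal part: position i may only attend to positions <= i
    causal = torch.full((seq_len, seq_len), min_val, dtype=dtype).triu(diagonal=1)
    causal = causal[None, None, :, :]  # [1, 1, tgt_len, src_len]

    # padding part: broadcast [batch, seq_len] -> [batch, 1, 1, src_len]
    padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_val

    # summing the two yields a causal mask that also blocks padded positions
    return causal + padding
```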
This may also need to adjust the is_causal argument dynamically, as in the PR I linked, although I'm not sure whether it's just causal in general.
Thanks! I made some changes to the code after referring to CLIP - removing output_attentions, return_dict, and causal_attention_mask. I also copied the eager attention and attention reshaping from CLIP, and added the flash and flex attention paths.
I think the current CI is failing because the OWL-ViT config file conflicts with the current encoder implementation. Could you guide me here? Thanks a lot!
Hi, I investigated the failing OwlViTForObjectDetectionTest::test_eager_matches_sdpa_inference_09_fp32_pad_left.
The failure is due to the test invoking OwlViTForObjectDetection.forward() without providing pixel_values.
OwlViTForObjectDetection requires pixel_values (image tensors) for its vision backbone. When the test omits them, the model raises a ValueError: 'pixel_values' is None.
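For context, a small usage sketch (using the public google/owlvit-base-patch32 checkpoint, not part of this PR) showing that OwlViTForObjectDetection needs both text and image inputs:

```python
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

image = Image.new("RGB", (768, 768))  # placeholder image just to produce pixel_values
inputs = processor(text=[["a cat"]], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # works: input_ids and pixel_values are both provided

# model(input_ids=inputs["input_ids"])  # would raise, since pixel_values is required
```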
```python
base_model_prefix = "owlvit"
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
_supports_sdpa = True
```
It should also support flash attention and flex attention then.
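A sketch of what that could look like on top of the flags quoted above. The flag names mirror other transformers models; newer versions may use `_supports_flash_attn` instead of `_supports_flash_attn_2`, so verify against the current repo.

```python
from transformers import PreTrainedModel


class OwlViTPreTrainedModelSketch(PreTrainedModel):
    base_model_prefix = "owlvit"
    input_modalities = ["image", "text"]
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True  # flash attention kernels (name may differ in newer versions)
    _supports_flex_attn = True     # torch flex attention backend
```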
```python
# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->OwlViT
```
This doesn't work; output_attentions would be ignored. See the check_model_inputs decorator in other models like CLIP.
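Roughly, the pattern being referenced looks like the sketch below. The decorator name comes from this comment; the import path and exact behavior are assumptions that should be checked against the current CLIP implementation.

```python
from torch import nn
from transformers.modeling_outputs import BaseModelOutput
# import path is an assumption; verify against the current transformers source
from transformers.utils.generic import check_model_inputs


class OwlViTEncoderSketch(nn.Module):
    def __init__(self, config, layers):
        super().__init__()
        self.config = config  # the decorator reads default flags from the config
        self.layers = nn.ModuleList(layers)

    @check_model_inputs
    def forward(self, inputs_embeds, attention_mask=None, **kwargs):
        hidden_states = inputs_embeds
        for layer in self.layers:
            # layers no longer take or return output_attentions explicitly;
            # the decorator records per-layer outputs when requested
            hidden_states = layer(hidden_states, attention_mask=attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=hidden_states)
```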
```python
# Eager attention implementation
# Scale query
```
Please follow what is done in other models; we should be able to copy the eager function from BERT (see CLIP as an example).
```python
# Prepare attention mask - OWL-ViT uses bidirectional (non-causal) attention
attention_mask = create_bidirectional_mask(
    config=self.config,
    input_embeds=hidden_states,
    attention_mask=attention_mask,
)
```
The mask is prepared outside the attention module, not within it.
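A short sketch of that separation, using the `_prepare_4d_attention_mask` helper from the diff above; the surrounding encoder loop is illustrative, not the actual OwlViT code.

```python
import torch
from torch import nn
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask


def run_encoder(layers: nn.ModuleList, hidden_states: torch.Tensor, attention_mask=None):
    # the encoder expands the 2D padding mask to a 4D additive mask once...
    if attention_mask is not None:
        # [batch, seq_len] -> [batch, 1, tgt_len, src_len]
        attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
    # ...and every layer's attention module only consumes it
    for layer in layers:
        hidden_states = layer(hidden_states, attention_mask=attention_mask)
    return hidden_states
```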
```python
# Get query, key, value projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
# Reshape to (bsz, num_heads, seq_len, head_dim)
query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
```
Please follow LLaMA / CLIP; the -1 dim is also important, for example.
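For reference, a sketch of the LLaMA/CLIP-style projection and reshape; the projection layers are stand-ins for self.q_proj / k_proj / v_proj, and the point is the `-1` dimension and the `(batch, num_heads, seq_len, head_dim)` layout.

```python
import torch
from torch import nn


def project_qkv(hidden_states, q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear, head_dim: int):
    input_shape = hidden_states.shape[:-1]         # (batch, seq_len)
    hidden_shape = (*input_shape, -1, head_dim)    # -1 infers num_heads from the hidden size
    query = q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key = k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value = v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    # each is (batch, num_heads, seq_len, head_dim), as expected by eager/SDPA/flash backends
    return query, key, value
```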
Force-pushed from 41a76bf to 42971a3
[For maintainers] Suggested jobs to run (before merge): run-slow: owlv2, owlvit
What does this PR do?
Implements SDPA for OWL-ViT.
Fixes #28103
Before submitting
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@vasqu @younesbelkada