Remove graph breaks for torch.compile() in flash_attention_forward when Llama model is padding-free tuned #33932
Changes from 3 commits
d541997
35b2aa6
5cefb84
aa7b014
c42deaa
926481b
85f1330
01fb377
5ec657f
20a4dd6
045ef16
d2796f6
39d2868
83747b5
b642d45
d760818
d03e673
91f6fa1
80e0d5f
ca42b8b
b8d2568
ae11c96
76c51ca
2a69f6c
77c7a3d
5333e89
391715a
f23c955
ba54841
67c7828
8d2ec29
480c78d
5a903da
05f9a80
6843a9c
a6e2601
dd0bd9a
cb08b63
```diff
@@ -422,6 +422,10 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        cu_seq_lens_q: Optional[torch.LongTensor] = None,
+        cu_seq_lens_k: Optional[torch.LongTensor] = None,
+        max_length_q: int = 0,
+        max_length_k: int = 0,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if isinstance(past_key_value, StaticCache):
             raise ValueError(
```
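For context, these four new arguments are exactly what FlashAttention's varlen kernel consumes for a packed, padding-free batch; passing them in explicitly means they no longer have to be recomputed from the attention mask with data-dependent ops, which is what triggers graph breaks under torch.compile(). A minimal sketch of the kernel call they ultimately feed, assuming flash-attn 2.x and a CUDA device (all shapes and values below are illustrative, not the PR's exact code):

```python
import torch
from flash_attn import flash_attn_varlen_func

# Two packed sequences of lengths 5 and 3 -> 8 total tokens, no padding.
cu_seq_lens_q = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
cu_seq_lens_k = cu_seq_lens_q
max_length_q = max_length_k = 5

num_heads, head_dim = 32, 64
q = torch.randn(8, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(8, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(8, num_heads, head_dim, dtype=torch.float16, device="cuda")

# The cumulative sequence lengths and max lengths go straight to the kernel,
# so nothing has to be derived from an attention mask at runtime.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seq_lens_q,
    cu_seqlens_k=cu_seq_lens_k,
    max_seqlen_q=max_length_q,
    max_seqlen_k=max_length_k,
    causal=True,
)
print(out.shape)  # torch.Size([8, 32, 64])
```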
```diff
@@ -495,6 +499,11 @@ def forward(
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
+
+        batch_size = query_states.size(0)
+        query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
+        key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
+        value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
```
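To make the reshape concrete: it collapses the batch and sequence dimensions so the query/key/value tensors look like one packed stream of tokens, which is the layout the varlen kernel expects. A standalone toy example (shapes are illustrative):

```python
import torch

# Illustrative layout: [batch, seq_len, num_heads, head_dim]
query_states = torch.randn(2, 128, 32, 64)

batch_size = query_states.size(0)
# Collapse batch and sequence dims into one token dimension:
# [batch * seq_len, num_heads, head_dim]
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
print(query_states.shape)  # torch.Size([256, 32, 64])
```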
Review comment on the reshape block added above:

I think we should do this in the `_flash_attention_forward` wrapper; that way we have 0 modeling changes, and all models will benefit from this easily!
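A rough sketch of that idea, as a hypothetical simplification rather than the actual `_flash_attention_forward` code: if the cumulative sequence lengths are already provided by the caller, flatten the tensors and call the varlen kernel directly instead of re-deriving lengths from the attention mask.

```python
# Hypothetical simplification of the idea, not the real _flash_attention_forward.
from typing import Optional

import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func


def _flash_attention_forward_sketch(
    query_states: torch.Tensor,  # [batch, seq_len, num_heads, head_dim]
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seq_lens_q: Optional[torch.Tensor] = None,
    cu_seq_lens_k: Optional[torch.Tensor] = None,
    max_length_q: int = 0,
    max_length_k: int = 0,
    causal: bool = True,
) -> torch.Tensor:
    if cu_seq_lens_q is not None:
        # Padding-free path: flatten to [total_tokens, num_heads, head_dim] and let
        # the caller-provided cu_seq_lens drive the varlen kernel. No mask-dependent
        # indexing happens here, so torch.compile() only sees static shape ops.
        batch_size = query_states.size(0)
        q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
        k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
        v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
        out = flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_length_q,
            max_seqlen_k=max_length_k,
            causal=causal,
        )
        # Restore the [batch, seq_len, num_heads, head_dim] layout for the caller.
        return out.reshape(batch_size, -1, out.size(-2), out.size(-1))
    # Padded path: regular FlashAttention call on the 4D tensors.
    return flash_attn_func(query_states, key_states, value_states, causal=causal)
```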
Actually, something we had planned 😅 cc @gante on having `generate` unpad the input!
@Cyrilvallez as well, if you want to have fun; IMO this can be quite impactful!
These are FlashAttention-specific. IMO it would make sense to just add them as `fa2_kwargs`, for example. We can use something like this:

`class TextKwargs(TypedDict, total=False):`
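Following that pattern, one possible shape for such a grouped kwargs container (the class name and exact fields here are a sketch, not a final API):

```python
from typing import Optional, TypedDict

import torch


class FlashAttentionKwargs(TypedDict, total=False):
    # FlashAttention-specific, padding-free arguments, grouped so they can be
    # threaded through model forwards without touching every signature.
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: Optional[int]
    max_length_k: Optional[int]
```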
Nice catch, this was breaking compile on my side as well.