Remove graph breaks for torch.compile() in flash_attention_forward when the Llama model is padding-free tuned #33932
Diff 1: flash attention utilities (`prepare_fa2_from_position_ids` / `_flash_attention_forward`)

@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
     return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


+flash_241 = is_flash_attn_greater_or_equal("2.4.1")
+deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+
+
 def _flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
@@ -194,6 +198,11 @@ def _flash_attention_forward(
     use_top_left_mask: bool = False,
     softcap: Optional[float] = None,
     deterministic: bool = None,
+    cu_seq_lens_q: Optional[torch.LongTensor] = None,
+    cu_seq_lens_k: Optional[torch.LongTensor] = None,
+    max_length_q: int = 0,
+    max_length_k: int = 0,
+    batch_size: int = 2,
 ):
     """
     Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
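The new keyword arguments let the caller hand precomputed varlen metadata to the kernel instead of deriving it from `position_ids` inside the compiled forward. As a rough illustration only (this is not the PR's collator code, and `get_cu_seq_lens_from_position_ids` is a hypothetical helper), the cumulative sequence lengths and max length could be computed once, outside the model, from packed `position_ids` that restart at 0 for every sequence:

```python
import torch

def get_cu_seq_lens_from_position_ids(position_ids: torch.Tensor):
    """Derive flash-attn varlen metadata from packed position_ids (sketch).

    position_ids has shape (1, total_tokens), where each packed sequence
    restarts at 0, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3].
    Returns (cu_seq_lens, max_length) in the format flash_attn_varlen_func expects.
    """
    flat = position_ids.flatten()
    # A new sequence starts wherever the position counter resets to 0.
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten()
    cu_seq_lens = torch.cat(
        [starts, torch.tensor([flat.numel()], device=flat.device)]
    ).to(torch.int32)
    max_length = int((cu_seq_lens[1:] - cu_seq_lens[:-1]).max())
    return cu_seq_lens, max_length

# Example: two packed sequences of lengths 3 and 5
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
cu, max_len = get_cu_seq_lens_from_position_ids(position_ids)
print(cu)       # tensor([0, 3, 8], dtype=torch.int32)
print(max_len)  # 5
```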
@@ -232,9 +241,9 @@ def _flash_attention_forward(
     )
     flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

-    if is_flash_attn_greater_or_equal("2.4.1"):
+    if flash_241:
         if deterministic is None:
-            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+            deterministic = deterministic_g
         flash_kwargs["deterministic"] = deterministic

     if softcap is not None:
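Hoisting the version check and the environment lookup to module scope means they are evaluated once at import time; inside the compiled forward they are then plain Python constants, which sidesteps the graph breaks these lookups were causing under `torch.compile`. A minimal sketch of the pattern, assuming a recent PyTorch with `torch.compile` (names are illustrative, not from the PR):

```python
import os
import torch

# Evaluated once at import time; inside the compiled function this is an
# ordinary Python constant that torch.compile can specialize on.
USE_DETERMINISTIC = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"

@torch.compile(fullgraph=True)  # fullgraph=True raises on any graph break
def scaled(x: torch.Tensor) -> torch.Tensor:
    # Reading os.environ here instead may force a graph break
    # (and hence an error under fullgraph=True).
    if USE_DETERMINISTIC:
        return x * 2.0
    return x * 3.0

print(scaled(torch.ones(2)))
```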
@@ -267,24 +276,15 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    # Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
-    elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
-        batch_size = query_states.size(0)
-        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
-            query_states, key_states, value_states, position_ids
-        )
-
-        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
Review comment: We should keep this in case `cu_seq_lens_q`/`cu_seq_lens_k` etc. are not passed, so we can still compute them here!
-
+    elif position_ids is not None and max_length_q is not None:
         attn_output = flash_attn_varlen_func(
             query_states,
             key_states,
             value_states,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_in_batch_q,
-            max_seqlen_k=max_seqlen_in_batch_k,
+            cu_seqlens_q=cu_seq_lens_q,
+            cu_seqlens_k=cu_seq_lens_k,
+            max_seqlen_q=max_length_q,
+            max_seqlen_k=max_length_k,
             dropout_p=dropout,
             softmax_scale=softmax_scale,
             causal=causal,
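With the metadata passed in explicitly, the padding-free branch no longer needs the data-dependent `torch.diff(position_ids)` check or the on-the-fly `prepare_fa2_from_position_ids` call inside the compiled region. For reference, a self-contained sketch of calling `flash_attn_varlen_func` directly on two packed sequences (assumes the `flash-attn` package and a CUDA device; shapes and values are illustrative):

```python
import torch
from flash_attn import flash_attn_varlen_func  # pip install flash-attn

num_heads, head_dim = 8, 64
# Two sequences of lengths 3 and 5 packed into one ragged batch of
# 8 total tokens: tensors have shape (total_tokens, num_heads, head_dim).
total_tokens = 8
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
max_seqlen = 5

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,  # each packed sequence attends only to its own past tokens
)
print(out.shape)  # torch.Size([8, 8, 64])
```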
Diff 2: Llama modeling code

@@ -48,6 +48,9 @@
     logging,
     replace_return_docstrings,
 )
+from ...processing_utils import (
+    Fa2Kwargs,
+)
 from .configuration_llama import LlamaConfig


@@ -421,6 +424,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        **kwargs,
Suggested change:
-        **kwargs,
+        **kwargs: Unpack[Fa2Kwargs],
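Annotating the catch-all as `Unpack[Fa2Kwargs]` tells type checkers exactly which flash-attention kwargs may flow through `**kwargs`. A sketch of what such a TypedDict could look like, assuming it mirrors the new `_flash_attention_forward` parameters; the actual `Fa2Kwargs` definition in `processing_utils` may differ:

```python
from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack


class Fa2Kwargs(TypedDict, total=False):
    """Flash Attention 2 metadata forwarded through **kwargs (sketch only)."""

    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: int
    max_length_k: int


def forward(hidden_states: torch.Tensor, **kwargs: Unpack[Fa2Kwargs]) -> torch.Tensor:
    # Type checkers now know exactly which extra keyword arguments are allowed here.
    print(kwargs.get("max_length_q", 0))
    return hidden_states


forward(torch.ones(1), max_length_q=128)
```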
Referenced lines (from the modeling diff):
        batch_size = query_states.size(0)
        query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
        key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
        value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
Review comment: I think we should do this in the `_flash_attention_forward` wrapper; that way we have zero modeling changes, and all models will benefit easily from this!
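The quoted reshape collapses `(batch, seq_len, num_heads, head_dim)` tensors into the `(total_tokens, num_heads, head_dim)` ragged layout that `flash_attn_varlen_func` expects; moving it into `_flash_attention_forward`, as suggested, keeps the per-model code untouched. A sketch of that flattening as a standalone helper (the function name is illustrative, not the merged code):

```python
import torch

def _flatten_for_varlen(query_states, key_states, value_states):
    """Collapse the batch dimension so packed sequences form one ragged batch.

    Input:  (batch, seq_len, num_heads, head_dim)
    Output: (batch * seq_len, num_heads, head_dim), plus the original batch size
            so the output can be reshaped back after the varlen kernel.
    """
    batch_size = query_states.size(0)
    query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
    key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
    value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
    return query_states, key_states, value_states, batch_size

q = torch.randn(2, 4, 8, 64)
q_flat, k_flat, v_flat, bsz = _flatten_for_varlen(q, q.clone(), q.clone())
print(q_flat.shape, bsz)  # torch.Size([8, 8, 64]) 2
```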
Suggested change:
-    **fa2_kwargs: Fa2Kwargs,
+    **fa2_kwargs: Unpack[Fa2Kwargs],
Review comment: Nice catch, this was breaking compile on my side as well.