Merged
Commits
38 commits
d541997
fix: fixes for graph breaks
Abhishek-TAMU Oct 3, 2024
35b2aa6
fix: formatting
Abhishek-TAMU Oct 3, 2024
5cefb84
fix: import error
Abhishek-TAMU Oct 3, 2024
aa7b014
fix: Add Fa2Kwargs
Abhishek-TAMU Oct 7, 2024
c42deaa
Merge branch 'main' into compile_llama
Abhishek-TAMU Oct 8, 2024
926481b
fix: PR Changes
Abhishek-TAMU Oct 9, 2024
85f1330
Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran…
Abhishek-TAMU Oct 9, 2024
01fb377
Merge branch 'main' into compile_llama
Abhishek-TAMU Oct 9, 2024
5ec657f
Merge branch 'main' into compile_llama
Abhishek-TAMU Oct 10, 2024
20a4dd6
PR changes
Abhishek-TAMU Oct 10, 2024
045ef16
PR changes
Abhishek-TAMU Oct 10, 2024
d2796f6
PR changes
Abhishek-TAMU Oct 11, 2024
39d2868
PR changes
Abhishek-TAMU Oct 11, 2024
83747b5
Revert "PR changes"
Abhishek-TAMU Oct 11, 2024
b642d45
PR changes
Abhishek-TAMU Oct 11, 2024
d760818
Merge branch 'huggingface:main' into compile_llama
Abhishek-TAMU Oct 14, 2024
d03e673
fix: FlashAttentionKwarg
Abhishek-TAMU Oct 14, 2024
91f6fa1
Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran…
Abhishek-TAMU Oct 14, 2024
80e0d5f
fix: FlashAttentionKwarg
Abhishek-TAMU Oct 14, 2024
ca42b8b
PR Changes
Abhishek-TAMU Oct 15, 2024
b8d2568
PR Changes
Abhishek-TAMU Oct 15, 2024
ae11c96
PR Changes
Abhishek-TAMU Oct 15, 2024
76c51ca
PR Changes
Abhishek-TAMU Oct 15, 2024
2a69f6c
Merge branch 'huggingface:main' into compile_llama
Abhishek-TAMU Oct 15, 2024
77c7a3d
PR Changes
Abhishek-TAMU Oct 16, 2024
5333e89
Merge remote-tracking branch 'huggingface/main' into compile_llama
Abhishek-TAMU Oct 18, 2024
391715a
addition of documentation
Abhishek-TAMU Oct 18, 2024
f23c955
change in _flash_attention_forward
Abhishek-TAMU Oct 21, 2024
ba54841
Merge remote-tracking branch 'huggingface/main' into compile_llama
Abhishek-TAMU Oct 22, 2024
67c7828
make fix-copies
Abhishek-TAMU Oct 22, 2024
8d2ec29
revert make fix-copies
Abhishek-TAMU Oct 22, 2024
480c78d
Merge remote-tracking branch 'huggingface/main' into compile_llama
Abhishek-TAMU Oct 23, 2024
5a903da
fix copies
ArthurZucker Oct 23, 2024
05f9a80
style
ArthurZucker Oct 23, 2024
6843a9c
Merge branch 'main' of github.com:huggingface/transformers into compi…
ArthurZucker Oct 23, 2024
a6e2601
loss kwargs typing
ArthurZucker Oct 23, 2024
dd0bd9a
Merge branch 'main' of github.com:huggingface/transformers into compi…
ArthurZucker Oct 24, 2024
cb08b63
style and pull latest changes
ArthurZucker Oct 24, 2024
23 changes: 12 additions & 11 deletions src/transformers/modeling_flash_attention_utils.py
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
Comment on lines +183 to +184
nice catch this was breaking compile on my side as well
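
(Illustrative sketch, not part of the diff: hoisting the version and environment checks to module scope means they are evaluated once at import time, so torch.compile never has to trace os.environ inside the attention forward. The names below are hypothetical.)

import os
from typing import Optional

import torch

# Evaluated once at import time, mirroring flash_241 / deterministic_g above,
# so the compiled region never reads os.environ (which was causing the graph break).
_DETERMINISTIC_DEFAULT = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"


def attention_stub(q: torch.Tensor, deterministic: Optional[bool] = None) -> torch.Tensor:
    # No environment or version lookups in here, so the function traces as a single graph.
    if deterministic is None:
        deterministic = _DETERMINISTIC_DEFAULT
    return q  # stand-in for the real flash-attention kernel call


compiled_stub = torch.compile(attention_stub, fullgraph=True)  # fullgraph=True errors on any graph break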



def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -194,6 +198,11 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
cu_seqlens_q: Optional[torch.LongTensor] = None,
cu_seqlens_k: Optional[torch.LongTensor] = None,
max_seqlen_in_batch_q: int = 0,
max_seqlen_in_batch_k: int = 0,
batch_size: int = 2,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +241,9 @@
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

if is_flash_attn_greater_or_equal("2.4.1"):
if flash_241:
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic

if softcap is not None:
@@ -267,15 +276,7 @@
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

elif position_ids is not None and max_seqlen_in_batch_q is not None:
attn_output = flash_attn_varlen_func(
query_states,
key_states,
38 changes: 38 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -422,6 +422,10 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -495,6 +499,11 @@ def forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

batch_size = query_states.size(0)
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
Suggested change
batch_size = query_states.size(0)
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

I think we should do this in the _flash_attention_forward wrapper, this way we have 0 modeling changes, and all models will benefit easily from this!
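
(Sketch only, with a hypothetical helper name: the wrapper itself could flatten the (batch, seq_len, num_heads, head_dim) tensors into the varlen layout whenever cumulative sequence lengths are supplied, so the per-model attention classes would not need these reshape lines at all.)

from typing import Optional, Tuple

import torch


def _maybe_flatten_for_varlen(
    query: torch.Tensor,  # (batch, seq_len, num_heads, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    # Hypothetical helper for _flash_attention_forward: only the padding-free /
    # varlen path needs the (total_tokens, num_heads, head_dim) layout.
    batch_size = query.size(0)
    if cu_seqlens_q is not None:
        query = query.reshape(-1, query.size(-2), query.size(-1))
        key = key.reshape(-1, key.size(-2), key.size(-1))
        value = value.reshape(-1, value.size(-2), value.size(-1))
    return query, key, value, batch_size

The wrapper would then pass the flattened tensors to flash_attn_varlen_func and use batch_size to restore the output shape, keeping the modeling files unchanged.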


attn_output = _flash_attention_forward(
query_states,
key_states,
@@ -506,6 +515,11 @@
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_in_batch_q=max_length_q if isinstance(max_length_q, int) else max_length_q.item(),
max_seqlen_in_batch_k=max_length_k if isinstance(max_length_k, int) else max_length_k.item(),
batch_size=batch_size,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -644,6 +658,10 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -682,6 +700,10 @@
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -870,6 +892,10 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -953,6 +979,10 @@
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
Actually something we had planned 😅 cc @gante on generate unpadding the input!

@Cyrilvallez as well if you want to have fun IMO can be quite impactful!

)

hidden_states = layer_outputs[0]
@@ -1148,6 +1178,10 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
these are FlashAttention specific. IMO it would make sense to just add them as fa2_kwargs for example. We can use something like this:

class TextKwargs(TypedDict, total=False):
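
(Rough sketch of such a kwargs container: the field names come from this diff, while the class name and the Unpack-based **kwargs typing are assumptions in the spirit of the Fa2Kwargs / FlashAttentionKwarg commits, not the final API.)

from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack


class FlashAttentionKwargs(TypedDict, total=False):
    # Varlen metadata that decoder layers thread through untouched to the attention call.
    cu_seq_lens_q: Optional[torch.LongTensor]
    cu_seq_lens_k: Optional[torch.LongTensor]
    max_length_q: int
    max_length_k: int


def layer_forward(hidden_states: torch.Tensor, **kwargs: Unpack[FlashAttentionKwargs]) -> torch.Tensor:
    # The layer never inspects these kwargs, so non-FA2 attention backends can simply ignore them.
    return hidden_states  # placeholder body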

num_logits_to_keep: int = 0,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
@@ -1198,6 +1232,10 @@
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
)

hidden_states = outputs[0]
5 changes: 4 additions & 1 deletion src/transformers/tokenization_utils_base.py
@@ -31,6 +31,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from packaging import version

from . import __version__
@@ -813,7 +814,9 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items() if v is not None}
self.data = {
k: v.to(device=device) for k, v in self.data.items() if v is not None and isinstance(v, torch.Tensor)
}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self
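
(Illustrative usage, assuming a build that includes this change: a BatchEncoding can now carry non-tensor entries, e.g. plain ints for the new length metadata, and .to(device) skips them instead of raising.)

import torch
from transformers import BatchEncoding

enc = BatchEncoding({
    "input_ids": torch.tensor([[1, 2, 3]]),
    "max_length_q": 3,  # plain int; the old dict comprehension would fail on int.to(device)
})
enc = enc.to("cpu")  # only tensor values are moved, the int is left untouched
print(type(enc["input_ids"]), type(enc["max_length_q"]))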