Merged
Commits (38 total; the diff shown below is from 8 commits)
d541997  fix: fixes for graph breaks (Abhishek-TAMU, Oct 3, 2024)
35b2aa6  fix: formatting (Abhishek-TAMU, Oct 3, 2024)
5cefb84  fix: import error (Abhishek-TAMU, Oct 3, 2024)
aa7b014  fix: Add Fa2Kwargs (Abhishek-TAMU, Oct 7, 2024)
c42deaa  Merge branch 'main' into compile_llama (Abhishek-TAMU, Oct 8, 2024)
926481b  fix: PR Changes (Abhishek-TAMU, Oct 9, 2024)
85f1330  Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran… (Abhishek-TAMU, Oct 9, 2024)
01fb377  Merge branch 'main' into compile_llama (Abhishek-TAMU, Oct 9, 2024)
5ec657f  Merge branch 'main' into compile_llama (Abhishek-TAMU, Oct 10, 2024)
20a4dd6  PR changes (Abhishek-TAMU, Oct 10, 2024)
045ef16  PR changes (Abhishek-TAMU, Oct 10, 2024)
d2796f6  PR changes (Abhishek-TAMU, Oct 11, 2024)
39d2868  PR changes (Abhishek-TAMU, Oct 11, 2024)
83747b5  Revert "PR changes" (Abhishek-TAMU, Oct 11, 2024)
b642d45  PR changes (Abhishek-TAMU, Oct 11, 2024)
d760818  Merge branch 'huggingface:main' into compile_llama (Abhishek-TAMU, Oct 14, 2024)
d03e673  fix: FlashAttentionKwarg (Abhishek-TAMU, Oct 14, 2024)
91f6fa1  Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran… (Abhishek-TAMU, Oct 14, 2024)
80e0d5f  fix: FlashAttentionKwarg (Abhishek-TAMU, Oct 14, 2024)
ca42b8b  PR Changes (Abhishek-TAMU, Oct 15, 2024)
b8d2568  PR Changes (Abhishek-TAMU, Oct 15, 2024)
ae11c96  PR Changes (Abhishek-TAMU, Oct 15, 2024)
76c51ca  PR Changes (Abhishek-TAMU, Oct 15, 2024)
2a69f6c  Merge branch 'huggingface:main' into compile_llama (Abhishek-TAMU, Oct 15, 2024)
77c7a3d  PR Changes (Abhishek-TAMU, Oct 16, 2024)
5333e89  Merge remote-tracking branch 'huggingface/main' into compile_llama (Abhishek-TAMU, Oct 18, 2024)
391715a  addition of documentation (Abhishek-TAMU, Oct 18, 2024)
f23c955  change in _flash_attention_forward (Abhishek-TAMU, Oct 21, 2024)
ba54841  Merge remote-tracking branch 'huggingface/main' into compile_llama (Abhishek-TAMU, Oct 22, 2024)
67c7828  make fix-copies (Abhishek-TAMU, Oct 22, 2024)
8d2ec29  revert make fix-copies (Abhishek-TAMU, Oct 22, 2024)
480c78d  Merge remote-tracking branch 'huggingface/main' into compile_llama (Abhishek-TAMU, Oct 23, 2024)
5a903da  fix copies (ArthurZucker, Oct 23, 2024)
05f9a80  style (ArthurZucker, Oct 23, 2024)
6843a9c  Merge branch 'main' of github.com:huggingface/transformers into compi… (ArthurZucker, Oct 23, 2024)
a6e2601  loss kwargs typing (ArthurZucker, Oct 23, 2024)
dd0bd9a  Merge branch 'main' of github.com:huggingface/transformers into compi… (ArthurZucker, Oct 24, 2024)
cb08b63  style and pull latest changes (ArthurZucker, Oct 24, 2024)
32 changes: 16 additions & 16 deletions src/transformers/modeling_flash_attention_utils.py
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
Comment on lines +183 to +184
Collaborator:
nice catch, this was breaking compile on my side as well
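For context on this change: a minimal, self-contained sketch (the function names are illustrative, not from the diff) of why hoisting these lookups to import time helps `torch.compile`. Reading `os.environ` or running the flash-attn version check inside the forward pass can force Dynamo to break the graph, while a module-level constant is simply baked into the compiled region.

```python
import os

import torch

# Illustrative anti-pattern: the environment lookup happens inside the
# function being compiled, so Dynamo may have to break the graph around it.
def attention_flag_inside(x: torch.Tensor) -> torch.Tensor:
    deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
    return x * (2.0 if deterministic else 1.0)

# Pattern used in the diff: evaluate the flag once at import time; the
# compiled function only ever sees a plain Python constant.
DETERMINISTIC_G = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"

def attention_flag_hoisted(x: torch.Tensor) -> torch.Tensor:
    return x * (2.0 if DETERMINISTIC_G else 1.0)

# The hoisted variant can be traced as a single graph.
compiled = torch.compile(attention_flag_hoisted, fullgraph=True)
```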



def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
Expand All @@ -194,6 +198,11 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: int = 0,
max_length_k: int = 0,
batch_size: int = 2,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +241,9 @@
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

- if is_flash_attn_greater_or_equal("2.4.1"):
+ if flash_241:
if deterministic is None:
- deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic

if softcap is not None:
@@ -267,24 +276,15 @@
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
- elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
- batch_size = query_states.size(0)
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
- query_states, key_states, value_states, position_ids
- )
-
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
Collaborator:
We should keep this in case cu_seq_lens_q/k etc. are not passed, so we can compute them!
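A rough sketch of the fallback being requested here, assuming the kwarg names used in this diff; the helper name is hypothetical, while `prepare_fa2_from_position_ids` is the existing function in this file.

```python
from typing import Optional

import torch
from transformers.modeling_flash_attention_utils import prepare_fa2_from_position_ids

def _maybe_compute_varlen_args(  # hypothetical helper, not part of the PR
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    position_ids: torch.Tensor,
    cu_seq_lens_q: Optional[torch.LongTensor] = None,
    cu_seq_lens_k: Optional[torch.LongTensor] = None,
    max_length_q: Optional[int] = None,
    max_length_k: Optional[int] = None,
):
    # Only fall back to recomputation when the caller did not pass the
    # cumulative sequence lengths through FlashAttentionKwargs.
    if cu_seq_lens_q is None or cu_seq_lens_k is None:
        (query_states, key_states, value_states, _indices_q,
         (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)) = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )
    return query_states, key_states, value_states, cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k
```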


+ elif position_ids is not None and max_length_q is not None:
attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
18 changes: 16 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
- from typing import List, Optional, Tuple, Union
+ from typing import List, Optional, Tuple, Union, Unpack
Collaborator:
Suggested change:
- from typing import List, Optional, Tuple, Union, Unpack
+ from typing import List, Optional, Tuple, Union


import torch
import torch.nn.functional as F
@@ -48,6 +48,9 @@
logging,
replace_return_docstrings,
)
from ...processing_utils import (
FlashAttentionKwargs,
)
Collaborator:
Suggested change:
- from ...processing_utils import (
-     FlashAttentionKwargs,
- )
+ from ...processing_utils import (
+     FlashAttentionKwargs, Unpack
+ )
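For context, a minimal self-contained sketch of the typing pattern these two suggestions are adjusting: `Unpack` wraps a `TypedDict` so that `**kwargs` get per-key types, and it has to be imported from wherever the project re-exports it (here from `typing_extensions`; the class and function names below are illustrative).

```python
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack  # the PR re-exports this next to FlashAttentionKwargs

class ExampleAttnKwargs(TypedDict, total=False):
    cu_seq_lens_q: Optional[torch.LongTensor]
    max_length_q: Optional[int]

def forward(hidden_states: torch.Tensor, **kwargs: Unpack[ExampleAttnKwargs]) -> torch.Tensor:
    # Type checkers now know kwargs may only contain the keys declared above.
    max_length_q = kwargs.get("max_length_q")
    return hidden_states if max_length_q is None else hidden_states[:, :max_length_q]

out = forward(torch.zeros(1, 8, 16), max_length_q=4)
```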

from .configuration_llama import LlamaConfig


@@ -421,6 +424,7 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -494,6 +498,11 @@ def forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

batch_size = query_states.size(0)
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
Collaborator:
Suggested change (remove these lines):
- batch_size = query_states.size(0)
- query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
- key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
- value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

I think we should do this in the _flash_attention_forward wrapper; this way we have 0 modeling changes, and all models will benefit easily from this!
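A rough sketch of what this suggestion would look like, with a hypothetical helper name: do the flattening to the (total_tokens, num_heads, head_dim) layout that flash_attn_varlen_func expects inside the shared wrapper rather than in each model's attention class.

```python
from typing import Tuple

import torch

def _flatten_for_varlen(  # hypothetical helper living next to _flash_attention_forward
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Collapse the batch and sequence dimensions so the tensors match the
    (total_tokens, num_heads, head_dim) layout used by varlen flash attention."""
    batch_size = query_states.size(0)
    query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
    key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
    value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
    return query_states, key_states, value_states, batch_size

q = torch.zeros(2, 5, 8, 64)  # (batch, seq_len, num_heads, head_dim)
q_flat, k_flat, v_flat, bsz = _flatten_for_varlen(q, q.clone(), q.clone())
assert q_flat.shape == (10, 8, 64) and bsz == 2
```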


attn_output = _flash_attention_forward(
query_states,
key_states,
@@ -505,6 +514,8 @@
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
batch_size=batch_size,
**kwargs
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -868,7 +879,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -950,6 +961,7 @@
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs
)

hidden_states = layer_outputs[0]
@@ -1146,6 +1158,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1195,6 +1208,7 @@
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**flash_attn_kwargs,
)

hidden_states = outputs[0]
21 changes: 21 additions & 0 deletions src/transformers/processing_utils.py
@@ -23,6 +23,7 @@
import sys
import typing
import warnings
import torch
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

@@ -77,6 +78,26 @@
else:
Unpack = typing_extensions.Unpack

class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with Compile.

Attributes:
cu_seq_lens_q (`torch.LongTensor`, *optional*):
Cumulative sequence lengths for the query state.
cu_seq_lens_k (`torch.LongTensor`, *optional*):
Cumulative sequence lengths for the key state.
max_length_q (`int`, *optional*):
Maximum sequence length for query state.
max_length_k (`int`, *optional*):
Maximum sequence length for key state.
"""

cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]
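A small usage sketch, assuming the import path introduced in this diff; the final model call is illustrative and left commented out because it needs a Flash Attention 2 model and a packed batch.

```python
import torch

from transformers.processing_utils import FlashAttentionKwargs  # path added in this diff

# Cumulative lengths describing a packed batch of two sequences (3 and 4 tokens).
fa_kwargs: FlashAttentionKwargs = {
    "cu_seq_lens_q": torch.tensor([0, 3, 7]),
    "cu_seq_lens_k": torch.tensor([0, 3, 7]),
    "max_length_q": 4,
    "max_length_k": 4,
}

# outputs = model(input_ids, position_ids=position_ids, **fa_kwargs)
```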


class TextKwargs(TypedDict, total=False):
"""
5 changes: 4 additions & 1 deletion src/transformers/tokenization_utils_base.py
@@ -31,6 +31,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from packaging import version

from . import __version__
@@ -815,7 +816,9 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
- self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
+ self.data = {
+ k: v.to(device=device) for k, v in self.data.items() if v is not None and isinstance(v, torch.Tensor)
+ }
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self