Merged (38 commits)
d541997  fix: fixes for graph breaks (Abhishek-TAMU, Oct 3, 2024)
35b2aa6  fix: formatting (Abhishek-TAMU, Oct 3, 2024)
5cefb84  fix: import error (Abhishek-TAMU, Oct 3, 2024)
aa7b014  fix: Add Fa2Kwargs (Abhishek-TAMU, Oct 7, 2024)
c42deaa  Merge branch 'main' into compile_llama (Abhishek-TAMU, Oct 8, 2024)
926481b  fix: PR Changes (Abhishek-TAMU, Oct 9, 2024)
85f1330  Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran… (Abhishek-TAMU, Oct 9, 2024)
01fb377  Merge branch 'main' into compile_llama (Abhishek-TAMU, Oct 9, 2024)
5ec657f  Merge branch 'main' into compile_llama (Abhishek-TAMU, Oct 10, 2024)
20a4dd6  PR changes (Abhishek-TAMU, Oct 10, 2024)
045ef16  PR changes (Abhishek-TAMU, Oct 10, 2024)
d2796f6  PR changes (Abhishek-TAMU, Oct 11, 2024)
39d2868  PR changes (Abhishek-TAMU, Oct 11, 2024)
83747b5  Revert "PR changes" (Abhishek-TAMU, Oct 11, 2024)
b642d45  PR changes (Abhishek-TAMU, Oct 11, 2024)
d760818  Merge branch 'huggingface:main' into compile_llama (Abhishek-TAMU, Oct 14, 2024)
d03e673  fix: FlashAttentionKwarg (Abhishek-TAMU, Oct 14, 2024)
91f6fa1  Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran… (Abhishek-TAMU, Oct 14, 2024)
80e0d5f  fix: FlashAttentionKwarg (Abhishek-TAMU, Oct 14, 2024)
ca42b8b  PR Changes (Abhishek-TAMU, Oct 15, 2024)
b8d2568  PR Changes (Abhishek-TAMU, Oct 15, 2024)
ae11c96  PR Changes (Abhishek-TAMU, Oct 15, 2024)
76c51ca  PR Changes (Abhishek-TAMU, Oct 15, 2024)
2a69f6c  Merge branch 'huggingface:main' into compile_llama (Abhishek-TAMU, Oct 15, 2024)
77c7a3d  PR Changes (Abhishek-TAMU, Oct 16, 2024)
5333e89  Merge remote-tracking branch 'huggingface/main' into compile_llama (Abhishek-TAMU, Oct 18, 2024)
391715a  addition of documentation (Abhishek-TAMU, Oct 18, 2024)
f23c955  change in _flash_attention_forward (Abhishek-TAMU, Oct 21, 2024)
ba54841  Merge remote-tracking branch 'huggingface/main' into compile_llama (Abhishek-TAMU, Oct 22, 2024)
67c7828  make fix-copies (Abhishek-TAMU, Oct 22, 2024)
8d2ec29  revert make fix-copies (Abhishek-TAMU, Oct 22, 2024)
480c78d  Merge remote-tracking branch 'huggingface/main' into compile_llama (Abhishek-TAMU, Oct 23, 2024)
5a903da  fix copies (ArthurZucker, Oct 23, 2024)
05f9a80  style (ArthurZucker, Oct 23, 2024)
6843a9c  Merge branch 'main' of github.com:huggingface/transformers into compi… (ArthurZucker, Oct 23, 2024)
a6e2601  loss kwargs typing (ArthurZucker, Oct 23, 2024)
dd0bd9a  Merge branch 'main' of github.com:huggingface/transformers into compi… (ArthurZucker, Oct 24, 2024)
cb08b63  style and pull latest changes (ArthurZucker, Oct 24, 2024)
16 changes: 8 additions & 8 deletions src/transformers/modeling_flash_attention_utils.py
@@ -198,10 +198,10 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
- cu_seqlens_q: Optional[torch.LongTensor] = None,
- cu_seqlens_k: Optional[torch.LongTensor] = None,
- max_seqlen_in_batch_q: int = 0,
- max_seqlen_in_batch_k: int = 0,
+ cu_seq_lens_q: Optional[torch.LongTensor] = None,
+ cu_seq_lens_k: Optional[torch.LongTensor] = None,
+ max_length_q: int = 0,
+ max_length_k: int = 0,
batch_size: int = 2,
):
"""
@@ -281,10 +281,10 @@ def _flash_attention_forward(
query_states,
key_states,
value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
+ cu_seqlens_q=cu_seq_lens_q,
+ cu_seqlens_k=cu_seq_lens_k,
+ max_seqlen_q=max_length_q,
+ max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
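For orientation, the renamed arguments above (`cu_seq_lens_*`, `max_length_*`) are what callers now pass; internally they are forwarded to `flash_attn_varlen_func` under that kernel's own parameter names, as the hunk shows. A minimal, hypothetical sketch of just that mapping (not the real `_flash_attention_forward`, which takes many more parameters):

```python
from typing import Optional

import torch


def varlen_attention_sketch(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seq_lens_q: Optional[torch.LongTensor] = None,
    cu_seq_lens_k: Optional[torch.LongTensor] = None,
    max_length_q: int = 0,
    max_length_k: int = 0,
    dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = True,
):
    # Requires the flash-attn package; the caller-facing names are translated
    # to the kernel's cu_seqlens_* / max_seqlen_* keywords.
    from flash_attn import flash_attn_varlen_func

    return flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seq_lens_q,
        cu_seqlens_k=cu_seq_lens_k,
        max_seqlen_q=max_length_q,
        max_seqlen_k=max_length_k,
        dropout_p=dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )
```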
42 changes: 9 additions & 33 deletions src/transformers/models/llama/modeling_llama.py
@@ -49,6 +49,9 @@
logging,
replace_return_docstrings,
)
+ from ...processing_utils import (
+     Fa2Kwargs,
+ )
from .configuration_llama import LlamaConfig


@@ -422,10 +425,7 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
- cu_seq_lens_q: Optional[torch.LongTensor] = None,
- cu_seq_lens_k: Optional[torch.LongTensor] = None,
- max_length_q: int = 0,
- max_length_k: int = 0,
+ **kwargs,
Review comment (Collaborator), suggested change:
- **kwargs,
+ **kwargs: Unpack[Fa2Kwargs],

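For background on the reviewer's suggestion: `Unpack[...]` (PEP 692) is what lets a static type checker associate `**kwargs` with the keys of a TypedDict such as `Fa2Kwargs`. A minimal, self-contained sketch of the pattern, using a hypothetical `DemoKwargs` rather than the PR's actual classes:

```python
from typing import Optional
from typing_extensions import TypedDict, Unpack


class DemoKwargs(TypedDict, total=False):
    # Hypothetical stand-in for Fa2Kwargs, using the same optional-key style.
    max_length_q: Optional[int]
    max_length_k: Optional[int]


def forward(hidden_size: int, **kwargs: Unpack[DemoKwargs]) -> int:
    # A type checker now knows the permitted keyword names and their types;
    # a plain `**kwargs` (without Unpack) stays untyped.
    return hidden_size + (kwargs.get("max_length_q") or 0)


forward(8, max_length_q=16)   # accepted
# forward(8, max_len_q=16)    # rejected by a type checker: unknown keyword
```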
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -515,11 +515,8 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
- cu_seqlens_q=cu_seq_lens_q,
- cu_seqlens_k=cu_seq_lens_k,
- max_seqlen_in_batch_q=max_length_q if isinstance(max_length_q, int) else max_length_q.item(),
- max_seqlen_in_batch_k=max_length_k if isinstance(max_length_k, int) else max_length_k.item(),
- batch_size=batch_size,
+ **kwargs
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -658,10 +655,6 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
- cu_seq_lens_q: Optional[torch.LongTensor] = None,
- cu_seq_lens_k: Optional[torch.LongTensor] = None,
- max_length_q: int = 0,
- max_length_k: int = 0,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -700,10 +693,6 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- cu_seq_lens_q=cu_seq_lens_q,
- cu_seq_lens_k=cu_seq_lens_k,
- max_length_q=max_length_q,
- max_length_k=max_length_k,
**kwargs,
)
hidden_states = residual + hidden_states
@@ -891,11 +880,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- cu_seq_lens_q: Optional[torch.LongTensor] = None,
- cu_seq_lens_k: Optional[torch.LongTensor] = None,
- max_length_q: int = 0,
- max_length_k: int = 0,
+ **fa2_kwargs: Fa2Kwargs,
Review comment (Collaborator), suggested change:
- **fa2_kwargs: Fa2Kwargs,
+ **fa2_kwargs: Unpack[Fa2Kwargs],

) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -979,10 +964,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
- cu_seq_lens_q=cu_seq_lens_q,
- cu_seq_lens_k=cu_seq_lens_k,
- max_length_q=max_length_q,
- max_length_k=max_length_k,
+ **fa2_kwargs
)

hidden_states = layer_outputs[0]
@@ -1178,11 +1160,8 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
- cu_seq_lens_q: Optional[torch.LongTensor] = None,
- cu_seq_lens_k: Optional[torch.LongTensor] = None,
- max_length_q: int = 0,
- max_length_k: int = 0,
num_logits_to_keep: int = 0,
+ **fa2_kwargs: Fa2Kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1232,10 +1211,7 @@
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
- cu_seq_lens_q=cu_seq_lens_q,
- cu_seq_lens_k=cu_seq_lens_k,
- max_length_q=max_length_q,
- max_length_k=max_length_k,
+ **fa2_kwargs,
)

hidden_states = outputs[0]
21 changes: 21 additions & 0 deletions src/transformers/processing_utils.py
@@ -23,6 +23,7 @@
import sys
import typing
import warnings
+ import torch
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

@@ -77,6 +78,26 @@
else:
Unpack = typing_extensions.Unpack

+ class Fa2Kwargs(TypedDict, total=False):
+     """
+     Keyword arguments for Flash Attention with Compile.
+
+     Attributes:
+         cu_seq_lens_q (`torch.LongTensor`, *optional*):
+             Cumulative sequence lengths for the query state.
+         cu_seq_lens_k (`torch.LongTensor`, *optional*):
+             Cumulative sequence lengths for the key state.
+         max_length_q (`int`, *optional*):
+             Maximum sequence length for the query state.
+         max_length_k (`int`, *optional*):
+             Maximum sequence length for the key state.
+     """
+
+     cu_seq_lens_q: Optional[torch.LongTensor]
+     cu_seq_lens_k: Optional[torch.LongTensor]
+     max_length_q: Optional[int]
+     max_length_k: Optional[int]
+

class TextKwargs(TypedDict, total=False):
"""
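As a usage illustration (not part of the PR), the new `Fa2Kwargs` keys can be assembled from per-sequence lengths and passed through a model's `**kwargs`; everything below other than the `Fa2Kwargs` field names is a hypothetical sketch:

```python
from typing import List

import torch

from transformers.processing_utils import Fa2Kwargs  # added by this PR


def build_fa2_kwargs(seq_lens: List[int]) -> Fa2Kwargs:
    # Cumulative sequence lengths start at 0 and end at the total token count,
    # the layout flash-attn's varlen kernels expect for cu_seqlens tensors.
    lens = torch.tensor(seq_lens, dtype=torch.long)
    cu = torch.cat([torch.zeros(1, dtype=torch.long), lens.cumsum(0)])
    return Fa2Kwargs(
        cu_seq_lens_q=cu,
        cu_seq_lens_k=cu,
        max_length_q=int(lens.max()),
        max_length_k=int(lens.max()),
    )


fa2_kwargs = build_fa2_kwargs([5, 3, 7])
# Passed through to the attention layers, e.g. (model name is hypothetical):
# outputs = model(input_ids, position_ids=position_ids, **fa2_kwargs)
```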