Merged
Changes from all commits (38 commits)
d541997
fix: fixes for graph breaks
Abhishek-TAMU Oct 3, 2024
35b2aa6
fix: formatting
Abhishek-TAMU Oct 3, 2024
5cefb84
fix: import error
Abhishek-TAMU Oct 3, 2024
aa7b014
fix: Add Fa2Kwargs
Abhishek-TAMU Oct 7, 2024
c42deaa
Merge branch 'main' into compile_llama
Abhishek-TAMU Oct 8, 2024
926481b
fix: PR Changes
Abhishek-TAMU Oct 9, 2024
85f1330
Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran…
Abhishek-TAMU Oct 9, 2024
01fb377
Merge branch 'main' into compile_llama
Abhishek-TAMU Oct 9, 2024
5ec657f
Merge branch 'main' into compile_llama
Abhishek-TAMU Oct 10, 2024
20a4dd6
PR changes
Abhishek-TAMU Oct 10, 2024
045ef16
PR changes
Abhishek-TAMU Oct 10, 2024
d2796f6
PR changes
Abhishek-TAMU Oct 11, 2024
39d2868
PR changes
Abhishek-TAMU Oct 11, 2024
83747b5
Revert "PR changes"
Abhishek-TAMU Oct 11, 2024
b642d45
PR changes
Abhishek-TAMU Oct 11, 2024
d760818
Merge branch 'huggingface:main' into compile_llama
Abhishek-TAMU Oct 14, 2024
d03e673
fix: FlashAttentionKwarg
Abhishek-TAMU Oct 14, 2024
91f6fa1
Merge branch 'compile_llama' of https://github.com/Abhishek-TAMU/tran…
Abhishek-TAMU Oct 14, 2024
80e0d5f
fix: FlashAttentionKwarg
Abhishek-TAMU Oct 14, 2024
ca42b8b
PR Changes
Abhishek-TAMU Oct 15, 2024
b8d2568
PR Changes
Abhishek-TAMU Oct 15, 2024
ae11c96
PR Changes
Abhishek-TAMU Oct 15, 2024
76c51ca
PR Changes
Abhishek-TAMU Oct 15, 2024
2a69f6c
Merge branch 'huggingface:main' into compile_llama
Abhishek-TAMU Oct 15, 2024
77c7a3d
PR Changes
Abhishek-TAMU Oct 16, 2024
5333e89
Merge remote-tracking branch 'huggingface/main' into compile_llama
Abhishek-TAMU Oct 18, 2024
391715a
addition of documentation
Abhishek-TAMU Oct 18, 2024
f23c955
change in _flash_attention_forward
Abhishek-TAMU Oct 21, 2024
ba54841
Merge remote-tracking branch 'huggingface/main' into compile_llama
Abhishek-TAMU Oct 22, 2024
67c7828
make fix-copies
Abhishek-TAMU Oct 22, 2024
8d2ec29
revert make fix-copies
Abhishek-TAMU Oct 22, 2024
480c78d
Merge remote-tracking branch 'huggingface/main' into compile_llama
Abhishek-TAMU Oct 23, 2024
5a903da
fix copies
ArthurZucker Oct 23, 2024
05f9a80
style
ArthurZucker Oct 23, 2024
6843a9c
Merge branch 'main' of github.com:huggingface/transformers into compi…
ArthurZucker Oct 23, 2024
a6e2601
loss kwargs typing
ArthurZucker Oct 23, 2024
dd0bd9a
Merge branch 'main' of github.com:huggingface/transformers into compi…
ArthurZucker Oct 24, 2024
cb08b63
style and pull latest changes
ArthurZucker Oct 24, 2024
93 changes: 93 additions & 0 deletions docs/source/en/llm_optims.md
@@ -348,6 +348,99 @@ model = AutoModelForCausalLM.from_pretrained(
)
```

### Fine-tuning with torch.compile and padding-free data collation

In addition to optimizing inference, you can also improve the training efficiency of large language models by enabling `torch.compile` during fine-tuning and using a padding-free data collator. A padding-free collator packs each batch without padding tokens, which removes wasted computation and can significantly speed up training.

Here's how you can fine-tune a Llama model with `SFTTrainer` from the TRL library, with `torch_compile` enabled in `TrainingArguments` and a padding-free data collator:

```py
#################### IMPORTS ###################
import dataclasses
import datasets
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments
)
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
#################### MODEL LOADING WITH FLASH ATTENTION ###################
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16, # FlashAttention-2 only supports fp16/bf16
attn_implementation="flash_attention_2" # Enables FlashAttention-2
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
#################### DATA PREPROCESSING (PADDING-FREE) ###################
response_template = "\n### Label:"
response_template_ids = tokenizer.encode(
response_template, add_special_tokens=False
)[2:] # Skip the first tokens, whose ids can change when the template follows other text
data_collator = DataCollatorForCompletionOnlyLM(
response_template=response_template_ids,
tokenizer=tokenizer,
ignore_index=-100,
padding_free=True # Enables padding-free collation
)
def format_dataset(example):
return {
"output": example["output"] + tokenizer.eos_token
}
data_files = {"train": "path/to/dataset"} # Replace with your dataset path
json_dataset = datasets.load_dataset("json", data_files=data_files)
formatted_train_dataset = json_dataset["train"].map(format_dataset)
################# TRAINING CONFIGURATION ############################
train_args = TrainingArguments(
num_train_epochs=5,
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=1e-5,
weight_decay=0.0,
warmup_ratio=0.03,
lr_scheduler_type="cosine",
logging_steps=1,
include_tokens_per_second=True,
save_strategy="epoch",
output_dir="output",
torch_compile=True, # Enables torch.compile
torch_compile_backend="inductor",
torch_compile_mode="default"
)
# Convert TrainingArguments to SFTConfig
transformer_train_arg_fields = [x.name for x in dataclasses.fields(SFTConfig)]
transformer_kwargs = {
k: v
for k, v in train_args.to_dict().items()
if k in transformer_train_arg_fields
}
training_args = SFTConfig(**transformer_kwargs)
####################### FINE-TUNING #####################
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=formatted_train_dataset,
data_collator=data_collator,
dataset_text_field="output",
args=training_args,
)
trainer.train()
```
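
To see what padding-free collation produces, here is a minimal illustrative sketch (the token ids below are made up, not the output of a real tokenizer). Instead of a rectangular padded batch, the collator returns a single packed row whose `position_ids` restart at 0 for each example, which is what routes attention through the FlashAttention varlen path during training:

```py
import torch

# Hypothetical token ids for two examples of lengths 4 and 3 (illustration only).
example_1 = [101, 7592, 2088, 102]
example_2 = [101, 3231, 102]

# A padding-free batch packs both examples into one row and rebuilds position_ids
# per example; no pad tokens or attention_mask are needed.
packed_batch = {
    "input_ids": torch.tensor([example_1 + example_2]),
    "position_ids": torch.tensor([[0, 1, 2, 3, 0, 1, 2]]),
    "labels": torch.tensor([[-100, -100, 2088, 102, -100, 3231, 102]]),
}

print(packed_batch["input_ids"].shape)  # torch.Size([1, 7])
```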

### PyTorch scaled dot product attention

Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
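
As a brief illustration (a minimal sketch; the checkpoint name is only an example and any SDPA-supported model works), you can request SDPA explicitly when loading a model:

```py
import torch
from transformers import AutoModelForCausalLM

# Explicitly request PyTorch scaled dot product attention.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
```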
65 changes: 51 additions & 14 deletions src/transformers/modeling_flash_attention_utils.py
@@ -15,7 +15,7 @@

import inspect
import os
from typing import Optional, Tuple
from typing import Optional, Tuple, TypedDict

import torch
import torch.nn.functional as F
@@ -180,6 +180,10 @@ def prepare_fa2_from_position_ids(query, key, value, position_ids):
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


flash_241 = is_flash_attn_greater_or_equal("2.4.1")
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
Review comment on lines +183 to +184 (Collaborator): nice catch, this was breaking compile on my side as well
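
For context, a minimal sketch (not the library code) of the pattern adopted here: the Flash Attention version check and the `FLASH_ATTENTION_DETERMINISTIC` environment read are hoisted to module scope, so they are evaluated once at import time rather than inside the compiled forward, where they previously caused graph breaks under `torch.compile`:

```py
import os

import torch

# Evaluated once at import time; torch.compile then treats it as a constant.
DETERMINISTIC = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"


def attention_stub(q: torch.Tensor) -> torch.Tensor:
    # Reading os.environ here instead would happen inside the traced region
    # and could force a graph break.
    scale = 0.5 if DETERMINISTIC else 1.0
    return q * scale


compiled = torch.compile(attention_stub, fullgraph=True)
print(compiled(torch.ones(2, 2)))
```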



def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -194,6 +198,10 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -232,9 +240,9 @@
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

if is_flash_attn_greater_or_equal("2.4.1"):
if flash_241:
if deterministic is None:
deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
deterministic = deterministic_g
flash_kwargs["deterministic"] = deterministic

if softcap is not None:
@@ -267,24 +275,32 @@
# If position_ids is provided and not all examples contain a single sequence: if position_ids is monotonically increasing
# we probably have one sequence per example, otherwise the batch is packed. Additionally check we are in the pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
# Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
elif position_ids is not None and (
max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
batch_size = query_states.size(0)
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
if cu_seq_lens_q is None or cu_seq_lens_k is None:
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
)

cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens

else:
query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

attn_output = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
@@ -299,3 +315,24 @@
)

return attn_output


class FlashAttentionKwargs(TypedDict, total=False):
"""
Keyword arguments for Flash Attention with `torch.compile`.
Attributes:
cu_seq_lens_q (`torch.LongTensor`, *optional*):
Cumulative sequence lengths for the query state.
cu_seq_lens_k (`torch.LongTensor`, *optional*):
Cumulative sequence lengths for the key state.
max_length_q (`int`, *optional*):
Maximum sequence length for the query state.
max_length_k (`int`, *optional*):
Maximum sequence length for the key state.
"""

cu_seq_lens_q: Optional[torch.LongTensor]
cu_seq_lens_k: Optional[torch.LongTensor]
max_length_q: Optional[int]
max_length_k: Optional[int]
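
As an illustrative sketch (this is not the library's `prepare_fa2_from_position_ids`), here is how these kwargs relate to `position_ids` for a packed batch of two sequences with lengths 3 and 2, which is what a padding-free collator can precompute and pass through the model:

```py
import torch

# One packed row holding two sequences of lengths 3 and 2.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

flat = position_ids.flatten()
starts = torch.nonzero(flat == 0).flatten()                         # tensor([0, 3])
seq_lens = torch.diff(starts, append=torch.tensor([flat.numel()]))  # tensor([3, 2])

cu_seq_lens = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)                                                                   # tensor([0, 3, 5])
max_length = int(seq_lens.max())                                    # 3

flash_attn_kwargs = {
    "cu_seq_lens_q": cu_seq_lens,
    "cu_seq_lens_k": cu_seq_lens,
    "max_length_q": max_length,
    "max_length_k": max_length,
}
```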
4 changes: 4 additions & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -33,12 +33,14 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_start_docstrings,
@@ -832,6 +834,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -913,6 +916,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
14 changes: 13 additions & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -38,6 +38,7 @@
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
@@ -51,7 +52,11 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward

from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...processing_utils import Unpack


_CHECKPOINT_FOR_DOC = "dummy"


class GlmRMSNorm(nn.Module):
@@ -736,6 +741,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -817,6 +823,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
@@ -1222,6 +1229,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
2 changes: 2 additions & 0 deletions src/transformers/models/glm/modular_glm.py
@@ -46,6 +46,8 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "dummy"


class GlmRMSNorm(Phi3RMSNorm):
pass
16 changes: 13 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
@@ -29,7 +29,7 @@
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -39,8 +39,10 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
LossKwargs,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -422,6 +424,7 @@ def forward(
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -506,6 +509,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
**kwargs,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -870,6 +874,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -951,6 +956,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]
@@ -1102,6 +1108,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

@@ -1148,7 +1157,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1198,6 +1207,7 @@
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs[0]
@@ -1211,7 +1221,7 @@

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
2 changes: 1 addition & 1 deletion src/transformers/tokenization_utils_base.py
@@ -815,7 +815,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) for k, v in self.data.items() if isinstance(v, torch.Tensor)}
self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self
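
For illustration, a small sketch of the behavioral change (the extra key is a made-up example): non-tensor entries in a `BatchEncoding`, such as the integer `max_length_q`/`max_length_k` values above, now survive a `.to(device)` call instead of being silently dropped:

```py
import torch
from transformers import BatchEncoding

batch = BatchEncoding({"input_ids": torch.tensor([[1, 2, 3]]), "max_length_q": 3})
moved = batch.to("cpu")

# Previously, non-tensor values were filtered out by the dict comprehension above;
# now they are carried through unchanged.
print(moved["input_ids"].device, moved["max_length_q"])  # cpu 3
```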