Commit c800883

Cyrilvallez authored and ArthurZucker committed
🚨🚨[core] Completely rewrite the masking logic for all attentions (huggingface#37866)
* start * start having a clean 4d mask primitive * Update mask_utils.py * Update mask_utils.py * switch name * Update masking_utils.py * add a new AttentionMask tensor class * fix import * nits * fixes * use full and quandrants * general sdpa mask for all caches * style * start some tests * tests with sliding, chunked * add styling * test hybrid * Update masking_utils.py * small temp fixes * Update modeling_gemma2.py * compile compatible * Update masking_utils.py * improve * start making it more general * Update masking_utils.py * generate * make it work with flex style primitives! * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * improve * Update cache_utils.py * Update masking_utils.py * simplify - starting to look good! * Update masking_utils.py * name * Update masking_utils.py * style * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * small fix for flex * flex compile * FA2 * Update masking_utils.py * Escape for TGI/vLLM! * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * General case without cache * rename * full test on llama4 * small fix for FA2 guard with chunk * Update modeling_gemma2.py * post rebase cleanup * FA2 supports static cache! * Update modeling_flash_attention_utils.py * Update flex_attention.py * Update masking_utils.py * Update masking_utils.py * Update utils.py * override for export * Update executorch.py * Update executorch.py * Update executorch.py * Update executorch.py * Update masking_utils.py * Update masking_utils.py * output attentions * style * Update masking_utils.py * Update executorch.py * Add doicstring * Add license and put mask visualizer at the end * Update test_modeling_common.py * fix broken test * Update test_modeling_gemma.py * Update test_modeling_gemma2.py * Use fullgraph=False with FA2 * Update utils.py * change name * Update masking_utils.py * improve doc * change name * Update modeling_attn_mask_utils.py * more explicit logic based on model's property * pattern in config * extend * fixes * make it better * generalize to other test models * fix * Update masking_utils.py * fix * do not check mask equivalence if layer types are different * executorch * Update modeling_gemma2.py * Update masking_utils.py * use layer_idx instead * adjust * Update masking_utils.py * test * fix imports * Update modeling_gemma2.py * other test models * Update modeling_llama4.py * Update masking_utils.py * improve * simplify * Update masking_utils.py * typos * typo * fix * Update masking_utils.py * default DynamicCache * remove default cache * simplify * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * simplify * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * export * Update executorch.py * Update executorch.py * Update flex_attention.py * Update executorch.py * upstream to modular gemma 1 & 2 * Update modular_mistral.py * switch names * use dict * put it in the Layer directly * update copy model source for mask functions * apply so many modular (hopefully 1 shot) * use explicite dicts for make style happy * protect import * check docstring * better default in hybrid caches * qwens * Update modular_qwen2.py * simplify core logic! 
* Update executorch.py * qwen3 moe * Update masking_utils.py * Update masking_utils.py * simplify a lot sdpa causal skip * Update masking_utils.py * post-rebase * gemma3 finally * style * check it before * gemma3 * More general with newer torch * align gemma3 * Update utils.py * Update utils.py * Update masking_utils.py * Update test_modeling_common.py * Update flex_attention.py * Update flex_attention.py * Update flex_attention.py * test * executorch * Update test_modeling_common.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update masking_utils.py * Update executorch.py * Update test_modeling_common.py * fix copies * device * sdpa can be used without mask -> pass the torchscript tests in this case * Use enum for check * revert enum and add check instead * remove broken test * cohere2 * some doc & reorganize the Interface * Update tensor_parallel.py * Update tensor_parallel.py * doc and dummy * Update test_modeling_paligemma2.py * Update modeling_falcon_h1.py * Update masking_utils.py * executorch patch * style * CIs * use register in executorch * final comments! --------- Co-authored-by: Arthur Zucker <[email protected]>
1 parent ce808ad commit c800883

File tree

129 files changed: +2984 -6808 lines changed

docs/source/en/attention_interface.md

Lines changed: 41 additions & 1 deletion
@@ -125,4 +125,44 @@ would expect from a usual Python dictionary:
 
 # You can also globally `register` a new function directly on it
 >>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)
-```
+```
+
+## Attention Mask Interface
+
+Having a new attention function may mean that you need a new format of attention mask to decide what key and value tokens
+the query tokens should attend to. This is now possible with the `AttentionMaskInterface`! It works in the same way as
+the `AttentionInterface`:
+
+```python
+from transformers import AttentionMaskInterface
+from transformers.masking_utils import sdpa_mask
+import torch
+
+def my_new_sdpa_mask(*args, **kwargs):
+    print("I just entered the attention mask computation")
+    return sdpa_mask(*args, **kwargs)
+
+AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
+```
+
+The reason you have to register it is because we need to automatically correct your mask format based on the attention implementation (for example, flex attention uses a BlockMask format, while sdpa uses a 4D tensor).
+By default, if you do not register an attention mask function along with your attention function, mask creation will be skipped
+and `attention_mask=None` will be passed along to the Attention layers.
+
+The default signature of the attention mask functions is the following:
+
+```python
+def custom_attention_mask(
+    batch_size: int,  # required arg
+    cache_position: torch.Tensor,  # required arg
+    kv_length: int,  # required arg
+    kv_offset: int = 0,  # required arg
+    mask_function: Callable = causal_mask_function,  # required arg
+    attention_mask: Optional[torch.Tensor] = None,  # required arg
+    **kwargs,  # a few additional args may be passed as kwargs, especially the model's config is always passed
+) -> Optional[torch.Tensor]:
+```
+
+It mostly works thanks to the `mask_function`, which is a `Callable` in the form of [torch's mask_mod functions](https://pytorch.org/blog/flexattention/), taking 4 indices as input and returning a boolean to indicate if this position should take part in the attention computation.
+
+If you cannot use the `mask_function` to create your mask for some reason, you can try to work around it by doing something similar to our [torch export workaround](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/executorch.py).
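
For reference, a minimal sketch of such a `mask_function` in the mask_mod style described by the new doc above (the name `sliding_causal_mask_function` and the window size of 8 are illustrative, not part of this commit): it receives the four indices (batch, head, query position, key/value position) and returns `True` when that pair should take part in attention.

```python
import torch

# Illustrative mask_mod-style callable: True means the (query, key/value) pair is attended to.
# Combines causality with a hypothetical sliding window of 8 positions.
def sliding_causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    return (kv_idx <= q_idx) & (q_idx - kv_idx < 8)

# Quick sanity check on scalar tensors
print(sliding_causal_mask_function(0, 0, torch.tensor(10), torch.tensor(5)))  # tensor(True)
print(sliding_causal_mask_function(0, 0, torch.tensor(20), torch.tensor(5)))  # tensor(False): outside the window
```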

docs/source/en/internal/modeling_utils.md

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,11 @@ Most of those are only useful if you are studying the code of the models in the
 [[autodoc]] AttentionInterface
     - register
 
+## Attention Mask Functions
+
+[[autodoc]] AttentionMaskInterface
+    - register
+
 ## Rotary Position Embedding Functions
 
 [[autodoc]] dynamic_rope_update

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -445,6 +445,7 @@
     _import_structure["modeling_outputs"] = []
     _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
     _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
+    _import_structure["masking_utils"] = ["AttentionMaskInterface"]
     _import_structure["optimization"] = [
         "Adafactor",
         "get_constant_schedule",
@@ -914,6 +915,7 @@
         TorchExportableModuleWithStaticCache,
         convert_and_export_with_cache,
     )
+    from .masking_utils import AttentionMaskInterface
     from .model_debugging_utils import (
         model_addition_debugger_context,
     )

src/transformers/cache_utils.py

Lines changed: 103 additions & 16 deletions
@@ -196,6 +196,18 @@ def seen_tokens(self):
         else:
             return None
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        query_length = cache_position.shape[0]
+        past_seen_tokens = self.get_seq_length()
+        kv_length = query_length + past_seen_tokens
+        return kv_length, 0
+
 
 @dataclass
 class CacheConfig:
@@ -1084,8 +1096,6 @@ class SinkCache(Cache):
     ```
     """
 
-    is_sliding = True
-
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         super().__init__()
         self.key_cache: List[torch.Tensor] = []
@@ -1390,6 +1400,16 @@ def reset(self):
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        kv_length = self.get_max_cache_shape()
+        return kv_length, 0
+
 
 class SlidingWindowCache(StaticCache):
     """
@@ -1446,7 +1466,6 @@ class SlidingWindowCache(StaticCache):
     ```
     """
 
-    is_sliding = True
     is_compileable = True
 
     def __init__(
@@ -1465,6 +1484,7 @@ def __init__(
                 "config and it's not set to None."
             )
         max_cache_len = min(config.sliding_window, max_cache_len)
+        self.sliding_window = config.sliding_window
         super().__init__(
             config=config,
             max_batch_size=max_batch_size,
@@ -1509,6 +1529,21 @@ def reset(self):
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        query_length = cache_position.shape[0]
+        first_cache_position = cache_position[0]
+        # torch.clamp() is equivalent to max() but should be compile-friendly/exportable as first_cache_position is a Tensor
+        kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+        kv_length = max(query_length, self.get_max_cache_shape())
+        return kv_length, kv_offset
+
 
 class EncoderDecoderCache(Cache):
     """
@@ -1761,12 +1796,17 @@ def __init__(
             else config.num_key_value_heads
         )
 
-        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
-        self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
+        else:
+            self.is_sliding = [False] * config.num_hidden_layers
+
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
+        self.sliding_window = min(config.sliding_window, max_cache_len)
         device = torch.device(device) if device is not None else None
         for i in range(config.num_hidden_layers):
             if layer_device_map is not None:
@@ -1775,7 +1815,7 @@ def __init__(
                 layer_device = device
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
-            cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
+            cache_shape = sliding_cache_shape if self.is_sliding[i] else global_cache_shape
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1796,7 +1836,7 @@ def update(
         if cache_position is None:
             raise ValueError("`cache_position` must be provided for HybridCache.")
 
-        is_sliding_layer = self.is_sliding_list[layer_idx]
+        is_sliding_layer = self.is_sliding[layer_idx]
 
         # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
         # when the cache is initialized in the forward pass (e.g. Gemma2)
@@ -1843,6 +1883,26 @@ def reset(self):
             self.key_cache[layer_idx].zero_()
             self.value_cache[layer_idx].zero_()
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        if self.is_sliding[layer_idx]:
+            query_length = cache_position.shape[0]
+            first_cache_position = cache_position[0]
+
+            local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+            # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+            local_mask_kv_length = max(query_length, self.sliding_window)
+            return local_mask_kv_length, local_mask_kv_offset
+
+        full_mask_kv_offset = 0
+        full_mask_kv_length = self.get_max_cache_shape()
+        return full_mask_kv_length, full_mask_kv_offset
+
 
 class HybridChunkedCache(Cache):
     """
@@ -1912,11 +1972,11 @@ def __init__(
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
         self._dtype = dtype
 
-        if hasattr(config.get_text_config(), "no_rope_layers"):
-            self.is_sliding = config.no_rope_layers
+        # If the attribute does not exist in the config, fallback to a simple StaticCache
+        if hasattr(config, "layer_types"):
+            self.is_sliding = [layer_type != "full_attention" for layer_type in config.layer_types]
         else:
-            layer_switch = getattr(config, "sliding_window_pattern", 2)
-            self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
+            self.is_sliding = [False] * config.num_hidden_layers
 
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
@@ -1999,11 +2059,7 @@ def update(
         key_states = key_states.to(k_out.dtype)
         value_states = value_states.to(v_out.dtype)
 
-        if self.is_sliding[layer_idx]:
-            update_fn = self._sliding_update
-        else:
-            update_fn = self._static_update
-
+        update_fn = self._sliding_update if self.is_sliding[layer_idx] else self._static_update
         return update_fn(
             cache_position,
             layer_idx,
@@ -2038,6 +2094,37 @@ def reset(self):
             self.value_cache[layer_idx].zero_()
         self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
 
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
+        for each layer.
+        """
+        if self.is_sliding[layer_idx]:
+            query_length = cache_position.shape[0]
+            first_cache_position = cache_position[0]
+
+            local_mask_kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+            # This is the true general case for any Cache using local attention (sliding or chunked)
+            if first_cache_position >= self.sliding_window:
+                # Here the Cache is already full
+                local_mask_kv_length = self.sliding_window + query_length - 1
+            elif (
+                first_cache_position < self.sliding_window
+                and first_cache_position + query_length > self.sliding_window
+            ):
+                # Here the Cache becomes full with the new input
+                local_mask_kv_length = first_cache_position + query_length
+            else:
+                # Here the Cache is still smaller than the local size, but we return the local size as it's static
+                local_mask_kv_length = self.sliding_window
+            return local_mask_kv_length, local_mask_kv_offset
+
+        full_mask_kv_offset = 0
+        full_mask_kv_length = self.get_max_cache_shape()
+        return full_mask_kv_length, full_mask_kv_offset
+
 
 class OffloadedHybridCache(HybridChunkedCache):
     def __init__(
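
To make the `(kv_length, kv_offset)` contract of these new `get_mask_sizes` methods concrete, here is an illustrative sketch (the helper `build_boolean_mask` is hypothetical; inside the library this expansion is handled by `masking_utils`, not by the cache) of how an offset/length pair can be turned into a boolean sliding-window mask over the cached keys:

```python
import torch

def build_boolean_mask(cache_position: torch.Tensor, kv_length: int, kv_offset: int, sliding_window: int) -> torch.Tensor:
    # Absolute key/value positions covered by the cached window
    kv_position = torch.arange(kv_length) + kv_offset      # shape (kv_length,)
    q_position = cache_position.unsqueeze(-1)               # shape (q_len, 1)
    causal = kv_position <= q_position                      # cannot attend to the future
    in_window = q_position - kv_position < sliding_window   # cannot attend too far back
    return causal & in_window                                # shape (q_len, kv_length)

# With sliding_window=4 and a single query at absolute position 10, SlidingWindowCache.get_mask_sizes
# above returns kv_length=4, kv_offset=7: only keys 7..10 are materialized, and all are visible here.
mask = build_boolean_mask(torch.tensor([10]), kv_length=4, kv_offset=7, sliding_window=4)
print(mask)  # tensor([[True, True, True, True]])
```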

src/transformers/configuration_utils.py

Lines changed: 13 additions & 0 deletions
@@ -1209,3 +1209,16 @@ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
 PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
     object="config", object_class="AutoConfig", object_files="configuration file"
 )
+
+
+ALLOWED_LAYER_TYPES = (
+    "full_attention",
+    "sliding_attention",
+    "chunked_attention",
+)
+
+
+def layer_type_validation(layer_types: list[str]):
+    """Check that each entry in `layer_types` are allowed."""
+    if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
+        raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
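
A quick usage sketch of the new validator (assuming it is imported straight from `transformers.configuration_utils`, where this diff defines it):

```python
from transformers.configuration_utils import layer_type_validation

layer_type_validation(["sliding_attention", "full_attention"])  # allowed entries: passes silently

try:
    layer_type_validation(["global_attention"])  # not one of the allowed layer types
except ValueError as err:
    print(err)  # The `layer_types` entries must be in ('full_attention', 'sliding_attention', 'chunked_attention')
```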

src/transformers/generation/utils.py

Lines changed: 30 additions & 5 deletions
@@ -46,6 +46,7 @@
 )
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..integrations.fsdp import is_fsdp_managed_module
+from ..masking_utils import create_masks_for_generate
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from ..pytorch_utils import isin_mps_friendly
 from ..tokenization_utils import ExtensionsTrie
@@ -74,6 +75,7 @@
 from .configuration_utils import (
     NEED_SETUP_CACHE_CLASSES_MAPPING,
     QUANT_BACKEND_CLASSES_MAPPING,
+    CompileConfig,
     GenerationConfig,
     GenerationMode,
 )
@@ -649,12 +651,22 @@ def prepare_inputs_for_generation(
             causal_mask_creation_function = getattr(
                 decoder, "_prepare_4d_causal_attention_mask_with_cache_position", None
             )
+
+            # If it's not defined, it means the model uses the new general mask API
             if causal_mask_creation_function is None:  # can't be found
-                logger.warning_once(
-                    f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
-                    "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
-                    "writing code, see Llama for an example implementation. If you're a user, please report this "
-                    "issue on GitHub."
+                output_attentions = kwargs.get("output_attentions", False)
+                token_type_ids = getattr(model_input, "token_type_ids", None)
+                # Some models may overwrite the general one
+                causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
+                attention_mask = causal_mask_creation_function(
+                    config=self.config,
+                    # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+                    input_embeds=torch.empty((batch_size, sequence_length), dtype=self.dtype),
+                    attention_mask=attention_mask,
+                    cache_position=cache_position,
+                    past_key_values=past_key_values,
+                    output_attentions=output_attentions,
+                    token_type_ids=token_type_ids,
                 )
             else:
                 attention_mask = causal_mask_creation_function(
@@ -3533,6 +3545,19 @@ def _sample(
         compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
         if compile_forward:
             os.environ["TOKENIZERS_PARALLELISM"] = "0"
+            # If we use FA2 and a static cache, we cannot compile with fullgraph
+            if self.config._attn_implementation == "flash_attention_2" and getattr(
+                model_kwargs.get("past_key_values"), "is_compileable", False
+            ):
+                if generation_config.compile_config is None:
+                    generation_config.compile_config = CompileConfig(fullgraph=False)
+                # only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
+                elif generation_config.compile_config.fullgraph:
+                    logger.warning_once(
+                        "When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
+                        "FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
+                    )
+                    generation_config.compile_config.fullgraph = False
             model_forward = self.get_compiled_call(generation_config.compile_config)
 
         if generation_config.prefill_chunk_size is not None: