🚨🚨[core] Completely rewrite the masking logic for all attentions #37866
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks very, very nice!
One thing I want to consider is to rather call the sliding, causal and chunked masks directly in the modeling.
For example:
- llama only needs causal_mask; under the hood the causal mask should do an AND with sdpa, flash or flex
- gemma needs sliding_causal: same
- llama4 needs chunked_causal
I want the modeling to call an explicit function, rather than the mega-general one (see the sketch below)!
This would keep our philosophy, as we don't want too-general stuff happening when not needed (e.g. llama should never care about sliding in its code paths).
Also missing: doc about how to add a new func!
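A rough sketch of the kind of explicit entry points described above: a shared causal primitive AND-ed with a sliding or chunked overlay. All names here are illustrative, not the API added by this PR:

```python
import torch

# Each primitive answers: may query position q_idx attend to key position kv_idx?
def causal(q_idx, kv_idx):
    return kv_idx <= q_idx

def sliding_window(window: int):
    def overlay(q_idx, kv_idx):
        return q_idx - kv_idx < window
    return overlay

def chunked_attention(chunk_size: int):
    def overlay(q_idx, kv_idx):
        return q_idx // chunk_size == kv_idx // chunk_size
    return overlay

def and_masks(*mask_fns):
    def combined(q_idx, kv_idx):
        result = torch.ones(q_idx.shape[0], kv_idx.shape[-1], dtype=torch.bool)
        for fn in mask_fns:
            result = result & fn(q_idx, kv_idx)
        return result
    return combined

def build_mask(mask_fn, q_len, kv_len):
    q_idx = torch.arange(q_len).view(-1, 1)    # (q_len, 1)
    kv_idx = torch.arange(kv_len).view(1, -1)  # (1, kv_len)
    return mask_fn(q_idx, kv_idx)              # broadcasts to (q_len, kv_len)

# What each model would explicitly ask for:
causal_mask = build_mask(causal, 5, 5)                                      # llama
sliding_causal = build_mask(and_masks(causal, sliding_window(3)), 5, 5)     # gemma
chunked_causal = build_mask(and_masks(causal, chunked_attention(2)), 5, 5)  # llama4
```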
Wow!!!!!!!! 🚀 This PR seems worth a full manual CI run. Ping me when you think this PR is ready to trigger CI.
Damn nice
Review of the core logic: IMO it can be simplified! But the modeling part is absolutely perfect!
For the visualization, I'll see how we could just overwrite the repr without affecting other operations!
Just a quick question on this refactor: if I understand the code correctly, the focus is currently on causal masks only, correct?
It would be nice to add a non-causal alternative which would only use a padding mask and expand it to q_len and kv_len respectively. That's more food for thought :D I don't want to make this PR even harder than it already is.
For now it's mostly on causal masks because they are the ones we need, but the idea is that it can be extended super easily from a set of mask primitives!
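To illustrate that kind of extension, a non-causal variant could be just a padding primitive expanded to (q_len, kv_len) with no causal constraint at all. Again only a sketch, not the actual primitive set from this PR:

```python
import torch

def padding_only_mask(padding_mask: torch.Tensor, q_len: int) -> torch.Tensor:
    """Expand a (batch, kv_len) padding mask to a bidirectional (batch, 1, q_len, kv_len) mask."""
    batch_size, kv_len = padding_mask.shape
    # Every query may attend to every non-padded key: no causal constraint.
    return padding_mask.bool()[:, None, None, :].expand(batch_size, 1, q_len, kv_len)

padding = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
mask = padding_only_mask(padding, q_len=3)  # shape (2, 1, 3, 4)
```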
Mega nice!
TODO before merging:
- move the causal_mask_mapping to a class attribute!
- show an example of how to register a new function, but minimal (without the sdpa correction, for example) (rough sketch below)
- make sure full-graph training is not broken maybe? Or at least FA2 training
That should be it!
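A guess at what such a minimal registration example might look like, reusing the sdpa_mask hook reviewed further down; the AttentionMaskInterface name and its register call are assumptions mirroring the existing AttentionInterface pattern:

```python
# Assumption: the registration entry point mirrors AttentionInterface; exact
# class/method names may differ from what actually ships in the library.
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask

def my_new_sdpa_mask(*args, **kwargs):
    # Wrap the stock sdpa mask primitive; no sdpa-specific correction, just a hook.
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

# Register under a custom key so a config/model can refer to it by name.
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)
```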
Let's go!
```python
print("I just entered the attention mask computation")
return sdpa_mask(*args, **kwargs)
```
Let's rather show how to do something like the paligemma or document masking here, something relevant!
Those are a bit different: it's modifying the mask pattern vs. adding a new mask format for the attention itself (both are complementary).
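For instance, document masking only changes the mask pattern: it is one more predicate AND-ed with the causal one. A self-contained sketch with made-up helper names, not code from this PR:

```python
import torch

def document_mask(document_ids: torch.Tensor):
    """Only allow attention between positions that belong to the same packed document."""
    def mask_fn(q_idx, kv_idx):
        return document_ids[q_idx] == document_ids[kv_idx]
    return mask_fn

# Two documents of lengths 3 and 2 packed into one sequence.
doc_ids = torch.tensor([0, 0, 0, 1, 1])
q_idx = torch.arange(5).view(-1, 1)
kv_idx = torch.arange(5).view(1, -1)
causal = kv_idx <= q_idx
mask = causal & document_mask(doc_ids)(q_idx, kv_idx)  # block-diagonal causal mask
```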
Hi @Cyrilvallez, I noticed that after this PR, calling prepare_inputs_for_generation with a hybrid cache no longer returns a plain tensor as the attention mask. Reproducer:

```python
from transformers import AutoModelForCausalLM
from transformers.cache_utils import HybridCache
import torch

model_id = 'hf-internal-testing/tiny-random-Gemma3ForCausalLM'
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = torch.arange(6).view(2, 3)
attention_mask = torch.ones_like(inputs)
# cache is required, w/o cache a tensor is returned as expected
cache = HybridCache(model.config, max_batch_size=2, max_cache_len=3)
model_kwargs = model.prepare_inputs_for_generation(
    inputs, attention_mask=attention_mask, past_key_values=cache, cache_position=torch.arange(3)
)
mask = model_kwargs['attention_mask']
assert isinstance(mask, torch.Tensor), f"expected attention mask to be tensor, got {mask}"
```

Before the PR, this assert passed.
```diff
 if not hasattr(model.config, "layer_types"):
     # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
     # export will use `StaticCache` by default.
     logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
     self.model = TorchExportableModuleWithStaticCache(model)
 else:
-    if model.config.cache_implementation == "hybrid":
-        self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
-    else:
-        raise ValueError(
-            f"Unsupported cache implementation: {model.config.cache_implementation}. "
-            "Please use `hybrid` or `static`."
-        )
+    self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
```
@Cyrilvallez What is layer_types? I'm concerned about whether the changes here are backwards compatible. For existing models on the Hub like google/gemma-3-1b, the config doesn't seem to come with layer_types, so it will fall back to the static cache, which doesn't look correct.
One comment I have is that in most models the mask calculation happens at the model level, e.g. here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L565-L566. However, different cache implementations may imply different attention masks. Different layers may have different cache implementations; for example, some layers can have sliding windows of different sizes, while others may use an attention sink to keep, say, the first few tokens. I feel the best place for the custom mask is at the attention layer, so that the layer can pass all the information, including the kv cache, to the custom mask function (e.g. layer_index).
Hey, sorry all, I was on vacation!

@BenjaminBossan indeed, this is expected. Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type).

@guangy10 if you look at the configs, e.g. here, you'll see that the attribute was added in a BC manner for all models that were refactored! Let me know if you notice any issue though!

@kimishpatel In transformers, Caches are not at the layer level, so as of now only some configurations are acceptable (though I've had in mind to change that for some time, to make it more modular). And computing the mask at the attention-layer level is not only redundant (most layers would create the same mask, wasting precious time), but it breaks compile completely, as we cannot pre-compute the masks anymore. For now, there are no known models with sliding windows of different sizes for different layers, so we decided to make it as simple as possible. This was taken into account when doing this refactor though, no worries; we definitely thought about it so it can scale easily in the future should this scenario happen.
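Roughly, the returned dict keys on each layer's attention type, and the decoding loop picks the matching mask per layer. A simplified, self-contained sketch (not verbatim library code; layer types and shapes are made up):

```python
import torch

# Hypothetical per-attention-type mask mapping, prepared once at the model level.
seq_len, window = 4, 2
full = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
sliding = full & (torch.arange(seq_len).view(-1, 1) - torch.arange(seq_len).view(1, -1) < window)
causal_mask_mapping = {"full_attention": full, "sliding_attention": sliding}

# Each decoder layer declares its attention type (as in config.layer_types) and
# simply indexes the mapping instead of recomputing its own mask.
layer_types = ["sliding_attention", "full_attention", "sliding_attention"]
for layer_idx, layer_type in enumerate(layer_types):
    attention_mask = causal_mask_mapping[layer_type]
    # ... the layer's forward pass would receive `attention_mask` here
```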
Resolves CI errors such as this one: https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

After resolving that error, other errors can occur, but they're unrelated and investigated independently.

After the transformers change in huggingface/transformers#37866, it can happen that:

> Models using different types of attention in different layers (i.e. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we need to adjust the code for the possibility of attention_mask being a dict. Right now, I simply extract the single value if the dict has just one element. For other sizes, I just raise an error, as I don't know how to deal with that. For our tests, this is enough, but we might need to find a better solution in the future.
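A sketch of the handling described above: unwrap the dict when it holds a single entry, otherwise refuse. The helper name is hypothetical and not the actual PEFT code:

```python
import torch

def unwrap_attention_mask(attention_mask):
    """Return a plain tensor, unwrapping the per-attention-type dict if possible."""
    if isinstance(attention_mask, dict):
        if len(attention_mask) == 1:
            # Only one attention type: its mask is the one to operate on.
            return next(iter(attention_mask.values()))
        raise ValueError(
            f"Cannot handle {len(attention_mask)} attention mask entries, expected exactly one."
        )
    return attention_mask

mask = unwrap_attention_mask({"full_attention": torch.ones(2, 1, 3, 3, dtype=torch.bool)})
```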
To be fair, I doubt attention mask calculation has that much impact on performance for most models. I have implemented a ring-buffer-based kv cache that needs a very different way of calculating the mask, and that mask calculation, while redundant, happens at the attention layer. I have not observed any significant amount of time spent there. Although for the block mask in flex attention, I think you might be right; that one is non-trivial.
How so? I do understand though that transformers is not exactly providing building blocks for model authoring, so from that perspective composability and modularity have limited value, I suppose.
@Cyrilvallez Let's follow up in #38646
What does this PR do?
As per the title. The goal is to properly separate masking logic from modeling code itself, to continue our objective of simplifying the library.