🚨🚨[core] Completely rewrite the masking logic for all attentions #37866

Merged: 193 commits from refactor-mask into main on May 22, 2025

Conversation

Cyrilvallez
Member

@Cyrilvallez Cyrilvallez commented Apr 29, 2025

What does this PR do?

As per the title. The goal is to properly separate masking logic from modeling code itself, to continue our objective of simplifying the library.

  • Code is much simpler to understand
  • Much more general: always works for all lengths and all attention implementations, e.g.:
    • flex attention now works with sliding/hybrid models (not the case before)
    • FA2 now works with static caches, including models with default hybrid structures (before, this was only the case for hybrid models)
  • All models can use all Cache classes (e.g. models with a hybrid structure can default back to using DynamicCache)
  • Extremely scalable in the future: any pattern of layers can be taken into account WITHOUT ANY CHANGE to modeling or masking. A new masking pattern (e.g. the recently introduced chunked attention for Llama4) can be added with minimal effort (just add a new mask_mod to describe it, and voila!); see the sketch after this list
  • A single source of truth: mask creation was copied over and over again, sometimes with slight changes to account for sliding windows or similar. This would eventually lead to mistakes or inefficiencies as things would be "forced to fit", and a lot of maintenance burden
  • Compile compatible: the new mask creation is technically compile compatible; it should however stay outside what is compiled in the forward to avoid recompilations, as is done in generate
  • Allows external mask creation: in case someone passes their own custom attention implementation, they may need their own mask creation function, which is now supported
  • TGI/vLLM backends should be even more efficient now, as we don't waste compute on creating a useless mask (we would previously create a 4d mask as for sdpa, which would not be used)
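
To make the mask_mod idea above concrete, here is a minimal, self-contained sketch in plain PyTorch. The names causal_mod, chunked_mod and materialize are illustrative stand-ins, not the primitives added by this PR; the point is only that a new pattern such as Llama4-style chunked attention reduces to one small boolean function composed with the causal rule.

import torch


def causal_mod(q_idx, kv_idx):
    # Queries attend only to keys at the same or earlier positions.
    return kv_idx <= q_idx


def chunked_mod(chunk_size):
    # Queries attend only to keys inside the same attention chunk.
    def inner(q_idx, kv_idx):
        return q_idx // chunk_size == kv_idx // chunk_size
    return inner


def chunked_causal_mod(chunk_size):
    # Compose the two rules: causal AND same-chunk.
    chunk = chunked_mod(chunk_size)
    return lambda q_idx, kv_idx: causal_mod(q_idx, kv_idx) & chunk(q_idx, kv_idx)


def materialize(mask_mod, q_len, kv_len):
    # Turn a mask_mod into the dense boolean mask an sdpa/eager backend consumes;
    # a flex-attention backend would instead use the mask_mod directly as a block-mask spec.
    q = torch.arange(q_len).view(-1, 1)
    kv = torch.arange(kv_len).view(1, -1)
    return mask_mod(q, kv)


print(materialize(chunked_causal_mod(chunk_size=3), q_len=6, kv_len=6))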

@github-actions github-actions bot marked this pull request as draft April 29, 2025 14:26
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez changed the title Refactor mask [core] Completely rewrite the masking logic for all attentions May 8, 2025
@Cyrilvallez Cyrilvallez force-pushed the refactor-mask branch 3 times, most recently from 53ca556 to ce42aa7 Compare May 12, 2025 07:49
@Cyrilvallez Cyrilvallez marked this pull request as ready for review May 12, 2025 16:25
Collaborator

@ArthurZucker ArthurZucker left a comment

Looks very very nice!
One thing I want to consider is to rather call the sliding, causal and chunked masks directly in the modeling.
For example:

  • llama only needs causal_mask; under the hood the causal mask should do an and with sdpa, flash or flex
  • gemma needs sliding_causal: same
  • llama4 needs chunked causal
    I want the modeling to call an explicit function, rather than the mega-general one!

This would keep our philosophy, as we don't want too-general stuff happening when not needed (e.g. llama should never care about sliding in its code paths); a sketch of what that could look like follows this comment.

Also missing: docs on how to add a new func!
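
For illustration, a minimal, self-contained sketch of that direction. create_causal_mask, create_sliding_window_causal_mask and the tiny backend registry below are assumed, simplified stand-ins for the helpers discussed here, not the transformers implementations.

import torch


def sdpa_mask(q_len, kv_len, window=None):
    # Dense boolean mask consumed by torch.nn.functional.scaled_dot_product_attention.
    q = torch.arange(q_len).view(-1, 1)
    kv = torch.arange(kv_len).view(1, -1)
    mask = kv <= q
    if window is not None:
        mask &= (q - kv) < window
    return mask


def flash_attention_mask(q_len, kv_len, window=None):
    # FA2 varlen kernels take no 4D mask; causal/sliding behaviour is passed as kernel
    # flags instead, so nothing is materialized here.
    return None


MASK_FUNCTIONS = {"sdpa": sdpa_mask, "flash_attention_2": flash_attention_mask}


def create_causal_mask(attn_implementation, q_len, kv_len):
    # What an explicit per-model helper conceptually does: pick the right format for the backend.
    return MASK_FUNCTIONS[attn_implementation](q_len, kv_len)


def create_sliding_window_causal_mask(attn_implementation, q_len, kv_len, window):
    return MASK_FUNCTIONS[attn_implementation](q_len, kv_len, window=window)


print(create_causal_mask("sdpa", 4, 4))                           # llama-style call
print(create_sliding_window_causal_mask("sdpa", 4, 4, window=2))  # gemma-style call
print(create_causal_mask("flash_attention_2", 4, 4))              # None: no mask built

The idea is that a llama-style modeling file only ever calls the plain causal helper and a gemma-style one only the sliding variant, while the backend-specific format (dense mask for sdpa, nothing for FA2, a block mask for flex) stays hidden behind that single call.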

@ydshieh
Collaborator

ydshieh commented May 15, 2025

Wow!!!!!!!! 🚀

This PR seems worth a manual full CI run. Ping me when you think this PR is ready to trigger CI.

Collaborator

@ArthurZucker ArthurZucker left a comment

Damn nice

Collaborator

@ArthurZucker ArthurZucker left a comment

Review for the core logic: IMO it can be simplified! But the modeling part is absolutely perfect!
For the visualization, I'll see how we could just overwrite the repr without affecting other operations!

Contributor

@vasqu vasqu left a comment

Just a quick question on this refactor: if I understand the code correctly, the focus is currently on causal masks only, correct?

It would be nice to add a non-causal alternative which would only use a padding mask and expand it to the q_len and kv_len respectively. That's more food for thought :D I don't want to make this PR even harder than it is.

@Cyrilvallez
Member Author

For now it's mostly causal masks because they are the ones we need, but the idea is that it can be extended super easily from a set of mask primitives!
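
As an illustration of such a primitive, here is a minimal sketch of the non-causal, padding-only mask @vasqu describes; the function name and shapes are assumptions for the example, not part of this PR.

import torch


def padding_only_mask(padding_mask, q_len):
    # padding_mask: (batch, kv_len) with 1 for real tokens, 0 for padding.
    # Returns (batch, 1, q_len, kv_len): every query may attend to every non-padded key.
    batch, kv_len = padding_mask.shape
    return padding_mask.bool().view(batch, 1, 1, kv_len).expand(batch, 1, q_len, kv_len)


padding = torch.tensor([[1, 1, 1, 0]])  # one sequence with a single padded position
print(padding_only_mask(padding, q_len=4))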

Collaborator

@ArthurZucker ArthurZucker left a comment

Mega nice!
TODO before merging:

  1. move the causal_mask_mapping to a class attribute!
  2. show an example of how to register a new function, but minimal (without the sdpa correction, for example)
  3. Make sure full-graph training is not broken, maybe? Or at least FA2 training

That should be it!

@Cyrilvallez Cyrilvallez changed the title [core] Completely rewrite the masking logic for all attentions 🚨🚨[core] Completely rewrite the masking logic for all attentions May 20, 2025
@Cyrilvallez Cyrilvallez force-pushed the refactor-mask branch 2 times, most recently from dee568c to 4a2e906 Compare May 21, 2025 11:30
Collaborator

@ArthurZucker ArthurZucker left a comment

Let's go!

Comment on lines +141 to +144
print("I just entered the attention mask computation")
return sdpa_mask(*args, **kwargs)

Collaborator

let's rather show how to do something like the paligemma or document masking here, something relevant!

Member Author

Those are a bit different: one is modifying the mask pattern, the other is adding a new mask format for the attention itself (both are complementary).
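
To make the distinction concrete, here is a minimal, self-contained sketch of the "modifying the mask pattern" side (document masking composed with the causal rule); the names are illustrative only. The snippet quoted above shows the other, complementary side: registering a new mask format for a custom attention implementation.

import torch


def document_causal_mod(document_ids):
    # Causal attention that is additionally blocked across packed documents.
    def inner(q_idx, kv_idx):
        same_doc = document_ids[q_idx] == document_ids[kv_idx]
        return (kv_idx <= q_idx) & same_doc
    return inner


doc_ids = torch.tensor([0, 0, 0, 1, 1])  # two documents packed into one sequence
q = torch.arange(5).view(-1, 1)
kv = torch.arange(5).view(1, -1)
print(document_causal_mod(doc_ids)(q, kv))  # causal, but never across documents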

@Cyrilvallez Cyrilvallez merged commit 163138a into main May 22, 2025
21 checks passed
@Cyrilvallez Cyrilvallez deleted the refactor-mask branch May 22, 2025 09:38
@BenjaminBossan
Member

Hi @Cyrilvallez I noticed that after this PR, calling prepare_inputs_for_generation can return an attention_mask that is a dict instead of a tensor. Is this expected? If yes, I need to update PEFT.

Reproducer:

from transformers import AutoModelForCausalLM
from transformers.cache_utils import HybridCache
import torch

model_id = 'hf-internal-testing/tiny-random-Gemma3ForCausalLM'
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = torch.arange(6).view(2, 3)
attention_mask = torch.ones_like(inputs)
# cache is required, w/o cache a tensor is returned as expected
cache = HybridCache(model.config, max_batch_size=2, max_cache_len=3)

model_kwargs = model.prepare_inputs_for_generation(
    inputs, attention_mask=attention_mask, past_key_values=cache, cache_position=torch.arange(3)
)
mask = model_kwargs['attention_mask']
assert isinstance(mask, torch.Tensor), f"expected attention mask to be tensor, got {mask}"

Before the PR (f8630c778c9220defecf1e3026d3438108b0baba), this passes. After the PR (163138a911c1fb4451ec4b32edaee20918a59def), it fails with:

AssertionError: expected attention mask to be tensor, got {'sliding_attention': tensor([[[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]],


        [[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]]])}

Comment on lines +60 to +66
+ if not hasattr(model.config, "layer_types"):
+     # If `layer_types` is not specified explicitly in the config, there is only 1 type of layers, so
+     # export will use `StaticCache` by default.
+     logging.info("Using `StaticCache` for export as `layer_types` is not specified in the config.")
+     self.model = TorchExportableModuleWithStaticCache(model)
+ else:
-     if model.config.cache_implementation == "hybrid":
-         self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
-     else:
-         raise ValueError(
-             f"Unsupported cache implementation: {model.config.cache_implementation}. "
-             "Please use `hybrid` or `static`."
-         )
+     self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
Contributor

@Cyrilvallez What is layer_types? I'm concerned about whether the changes here are backwards compatible. For existing models on the Hub like google/gemma-3-1b, the config doesn't seem to come with layer_types, so it will fall back to the static cache, which doesn't look correct.

@guangy10 guangy10 mentioned this pull request Jun 6, 2025
@kimishpatel

One comment I have is that in most models the mask calculation happens at the model level, e.g. here https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/modeling_gemma3.py#L565-L566. However, different cache implementations may imply different attention masks, and different layers may have different cache implementations: some layers can have sliding windows of different sizes, others may use an attention sink to keep, say, the first few tokens. I feel the best place for a custom mask is at the attention layer, so that the layer can pass all the information, including the kv cache (e.g. layer_index), to the custom mask function.

@Cyrilvallez
Member Author

Hey, sorry all, I was on vacation!

@BenjaminBossan indeed, this is expected. Models using different types of attention in different layers (e.g. gemma3) will now have a dict returned by prepare_inputs_for_generation (one dict entry per attention type). Sorry this broke your tests! 😬

@guangy10 if you look at the configs, e.g. here, you'll see that the attribute was added in a BC manner for all models that were refactored! Let me know if you notice any issue though!

@kimishpatel In transformers, Caches are not at the layer level, so as of now only some configurations are acceptable (though I've had it in mind for some time to change that and make it more modular). And computing the mask at the AttentionLayer-level is not only redundant (most layers will create the same mask, wasting precious time), but it breaks compile completely, as we cannot pre-compute the masks anymore. For now, there are no known models with sliding windows of different sizes for different layers, so we decided to make it as simple as possible. This was taken into account when doing the refactor though, no worries; we definitely thought about it so that it can scale easily in the future should this scenario happen.

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jun 10, 2025
Resolves CI errors such as this one:

https://github.com/huggingface/peft/actions/runs/15481482956/job/43588020111#step:5:53182

After resolving that error, other errors can occur, but they're
unrelated and investigated independently.

After the transformers change in
huggingface/transformers#37866, it can happen
that:

> Models using different types of attention in different layers (e.g.
gemma3) will now have a dict returned by
prepare_inputs_for_generation (one dict entry per attention type)

As PEFT operates on the attention mask for prompt learning methods, we
need to adjust the code for the possibility of attention_mask being a
dict. Right now, I simply extract the single value if the dict has just
one element. For other sizes, I just raise an error, as I don't know how
to deal with that. For our tests this is enough, but we might need to
find a better solution in the future.
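
A minimal sketch of that workaround with a hypothetical helper name (not PEFT's actual code): collapse a single-entry mask dict back to a tensor and error out otherwise.

import torch


def unwrap_attention_mask(attention_mask):
    # New transformers versions may return a dict keyed by attention type
    # (e.g. "sliding_attention"); older versions return a plain tensor.
    if isinstance(attention_mask, dict):
        if len(attention_mask) == 1:
            return next(iter(attention_mask.values()))
        raise ValueError(
            f"Cannot handle an attention_mask dict with keys {sorted(attention_mask)}; "
            "expected a single attention type."
        )
    return attention_mask


mask_dict = {"sliding_attention": torch.ones(2, 1, 3, 3, dtype=torch.bool)}
print(unwrap_attention_mask(mask_dict).shape)  # torch.Size([2, 1, 3, 3])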
redmoe-moutain pushed a commit to redmoe-moutain/transformers that referenced this pull request Jun 10, 2025
…gingface#37866)

* start

* start having a clean 4d mask primitive

* Update mask_utils.py

* Update mask_utils.py

* switch name

* Update masking_utils.py

* add a new AttentionMask tensor class

* fix import

* nits

* fixes

* use full and quandrants

* general sdpa mask for all caches

* style

* start some tests

* tests with sliding, chunked

* add styling

* test hybrid

* Update masking_utils.py

* small temp fixes

* Update modeling_gemma2.py

* compile compatible

* Update masking_utils.py

* improve

* start making it more general

* Update masking_utils.py

* generate

* make it work with flex style primitives!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* improve

* Update cache_utils.py

* Update masking_utils.py

* simplify - starting to look good!

* Update masking_utils.py

* name

* Update masking_utils.py

* style

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* small fix for flex

* flex compile

* FA2

* Update masking_utils.py

* Escape for TGI/vLLM!

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* General case without cache

* rename

* full test on llama4

* small fix for FA2 guard with chunk

* Update modeling_gemma2.py

* post rebase cleanup

* FA2 supports static cache!

* Update modeling_flash_attention_utils.py

* Update flex_attention.py

* Update masking_utils.py

* Update masking_utils.py

* Update utils.py

* override for export

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update executorch.py

* Update masking_utils.py

* Update masking_utils.py

* output attentions

* style

* Update masking_utils.py

* Update executorch.py

* Add doicstring

* Add license and put mask visualizer at the end

* Update test_modeling_common.py

* fix broken test

* Update test_modeling_gemma.py

* Update test_modeling_gemma2.py

* Use fullgraph=False with FA2

* Update utils.py

* change name

* Update masking_utils.py

* improve doc

* change name

* Update modeling_attn_mask_utils.py

* more explicit logic based on model's property

* pattern in config

* extend

* fixes

* make it better

* generalize to other test models

* fix

* Update masking_utils.py

* fix

* do not check mask equivalence if layer types are different

* executorch

* Update modeling_gemma2.py

* Update masking_utils.py

* use layer_idx instead

* adjust

* Update masking_utils.py

* test

* fix imports

* Update modeling_gemma2.py

* other test models

* Update modeling_llama4.py

* Update masking_utils.py

* improve

* simplify

* Update masking_utils.py

* typos

* typo

* fix

* Update masking_utils.py

* default DynamicCache

* remove default cache

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* simplify

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* export

* Update executorch.py

* Update executorch.py

* Update flex_attention.py

* Update executorch.py

* upstream to modular gemma 1 & 2

* Update modular_mistral.py

* switch names

* use dict

* put it in the Layer directly

* update copy model source for mask functions

* apply so many modular (hopefully 1 shot)

* use explicite dicts for make style happy

* protect import

* check docstring

* better default in hybrid caches

* qwens

* Update modular_qwen2.py

* simplify core logic!

* Update executorch.py

* qwen3 moe

* Update masking_utils.py

* Update masking_utils.py

* simplify a lot sdpa causal skip

* Update masking_utils.py

* post-rebase

* gemma3 finally

* style

* check it before

* gemma3

* More general with newer torch

* align gemma3

* Update utils.py

* Update utils.py

* Update masking_utils.py

* Update test_modeling_common.py

* Update flex_attention.py

* Update flex_attention.py

* Update flex_attention.py

* test

* executorch

* Update test_modeling_common.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update masking_utils.py

* Update executorch.py

* Update test_modeling_common.py

* fix copies

* device

* sdpa can be used without mask -> pass the torchscript tests in this case

* Use enum for check

* revert enum and add check instead

* remove broken test

* cohere2

* some doc & reorganize the Interface

* Update tensor_parallel.py

* Update tensor_parallel.py

* doc and dummy

* Update test_modeling_paligemma2.py

* Update modeling_falcon_h1.py

* Update masking_utils.py

* executorch patch

* style

* CIs

* use register in executorch

* final comments!

---------

Co-authored-by: Arthur Zucker <[email protected]>
@kimishpatel

kimishpatel commented Jun 10, 2025

And computing the mask at the AttentionLayer-level is not only redundant (most layers will create the same mask, wasting precious time),

To be fair, I doubt attention mask calculation has that much impact on performance for most models. I have implemented a ring-buffer-based kv cache that needs a very different way of calculating the mask, and that mask calculation, while redundant, happens at the attention layer. I have not observed any significant amount of time spent there.

Although I think for the block mask in flex attention, you might be right. That one is non-trivial.

Caches are not at the layer level,

How so? The cache's update functions accept layer_idx, so they do have to know which layer the update belongs to.

I do understand, though, that transformers is not exactly providing building blocks for model authoring, so from that perspective composability and modularity have limited value, I suppose.

@guangy10
Contributor

guangy10 commented Jun 10, 2025

@guangy10 if you look at the configs, e.g. here, you'll see that the attribute was added in a BC manner for all models that were refactored! Let me know if you notice any issue though!

@Cyrilvallez Let's follow up in #38646
