Skip to content

Make cache traceable #35873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 20, 2025
Merged

Make cache traceable #35873

merged 3 commits into from
Feb 20, 2025

Conversation

IlyasMoutawwakil
Copy link
Member

@IlyasMoutawwakil IlyasMoutawwakil commented Jan 24, 2025

What does this PR do?

In #35792 I came to the conclusion that tensor subclassing is a process that can only be achieved currently with some restriction (e.g. get_seq_length() needs to return a tensor), adding more developer cognitive load when adding a cache class. In this PR we make the cache traceable and exportable by not being a Module and registering cache tensors as buffers directly in TorchExportableModuleWithStaticCache.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

)
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why is non-persistent preferred in your opinion? Probably doesn't matter too much for inference as it will always start with filling the cache with prompt tokens even if they are persistent

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-persistent buffers are not saved with the model's state_dict, I think in the case of a big model with long sequence length static cache, the cache tensors will have nonnegligible memory footprint when exporting+saving the model

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil
Copy link
Member Author

Tests in optimum-executorch are passing as well huggingface/optimum-executorch#4

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUper cool, IMO just missing what this enables (as in, documentation about how to use this now ) might be in optimum directly?

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I remember we had a need for nn.Module: copy. We need to be able to copy the cache object (or clone) for prefix re-usage

@ArthurZucker
Copy link
Collaborator

@IlyasMoutawwakil
Copy link
Member Author

Also I remember we had a need for nn.Module: copy. We need to be able to copy the cache object (or clone) for prefix re-usage

There is a test for that I saw running locally https://github.com/huggingface/transformers/blob/main/tests/utils/test_cache_utils.py#L593C1-L622C60

Copy link
Contributor

@guangy10 guangy10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for validating the changes in Optimum huggingface/optimum-executorch#4, especially considering the inconvenience caused by the disruptions due to the migration to a new repository

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for double-cheking the changes in many places @IlyasMoutawwakil 🤗

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go if this test ran locally! 🚀

@tugsbayasgalan
Copy link
Contributor

Hey guys, i work on the torch.export team at PyTorch and i just wanted to verify that this change only works for Static kv cache right?

@IlyasMoutawwakil
Copy link
Member Author

Hey guys, i work on the torch.export team at PyTorch and i just wanted to verify that this change only works for Static kv cache right?

Hey ! not sure I understand, the change applies to Cache class which is inherited by all cache classes. And specifically for torch.export it removes the need to have Cache a subclass of torch.nn.Module.

@IlyasMoutawwakil
Copy link
Member Author

just updated the branch, will merge tonight after tests pass to be sure everything is good.

@eljandoubi
Copy link
Contributor

@IlyasMoutawwakil @ArthurZucker @gante @guangy10 detach Cahce from torch.nn.Module remove the .float() method from it which causes an error when calling convert_to_fp32
image

@IlyasMoutawwakil
Copy link
Member Author

self.float() should only be called if self is tensor or module, can you explain why it's called in this case ?

@gante
Copy link
Member

gante commented Mar 19, 2025

I saw the same stack trace in another issue, but the user didn't share a script to reproduce it. @eljandoubi could you kindly share a script to reproduce the issue? 🙏

@eljandoubi
Copy link
Contributor

@IlyasMoutawwakil the _is_fp16_bf16_tensor function checks if an object is tensor or has dtype attribute and that dtype is either fp16 or bf16.

@eljandoubi
Copy link
Contributor

eljandoubi commented Mar 20, 2025

@gante I'm afraid that I can't but I can provide guidance.When fine-tuning PaliGemma 2 mix in mixed precision (bf16), the evaluation in the training loop of the HF Trainer is done in fp32 and so convert_to_fp32 is called.

@IlyasMoutawwakil
Copy link
Member Author

This sounds like it was a silent error, because calling .float() on a Cache instance (when it was a nn.Module), doesn't actually do anything, since Cache.key_cache / Cache.value_cache are just lists and not module parameters.

In the case of fine-tuning I believe you need to pass use_cache=False so that Cache is not returned (and thus converting it is not attempted). We can implement .float() method that actually does the conversion of Cache.key_cache / Cache.value_cache depending on whether we want to silence this issue or force user to set use_cache=False when training, wdyt @gante ?

@eljandoubi
Copy link
Contributor

eljandoubi commented Mar 20, 2025

@IlyasMoutawwakil Thanks for enlightening me. I think set use_cache=False automatically in the Trainer like when using gradient checkpointing would be great. Plus ignore objects that have no .float() method in _is_fp16_bf16_tensor.

@gante
Copy link
Member

gante commented Mar 28, 2025

(see #37044)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants