Make cache traceable #35873
Conversation
```python
)
for i in range(len(self.static_cache.key_cache)):
    self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
```
Curious why non-persistent is preferred, in your opinion? It probably doesn't matter too much for inference, since generation will always start by filling the cache with prompt tokens even if the buffers are persistent.
Non-persistent buffers are not saved with the model's state_dict. For a big model with a long-sequence-length static cache, the cache tensors would have a non-negligible memory footprint when exporting and saving the model.
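For context, here is a minimal sketch (not from this PR; the module and buffer names are made up) showing how buffer persistence affects what ends up in the state_dict:

```python
import torch
import torch.nn as nn

class CacheHolder(nn.Module):
    # Hypothetical module holding two pre-allocated cache tensors as buffers.
    def __init__(self, max_len=4096, head_dim=128):
        super().__init__()
        self.register_buffer("key_cache_0", torch.zeros(1, max_len, head_dim), persistent=True)
        self.register_buffer("value_cache_0", torch.zeros(1, max_len, head_dim), persistent=False)

holder = CacheHolder()
# Only the persistent buffer is part of the state_dict, so only it gets
# written to disk when the wrapper is saved; the non-persistent one does not.
print(list(holder.state_dict().keys()))  # ['key_cache_0']
```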
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Tests in
Super cool! IMO it's just missing documentation about what this enables (as in, how to use it now); maybe that lives in Optimum directly?
Also, I remember we had a need for nn.Module copying: we need to be able to copy (or clone) the cache object for prefix re-use.
Can you check that https://github.com/huggingface/huggingface-llama-recipes/blob/main/performance_optimization/prompt_reuse.py still works?
There is a test for that which I saw running locally: https://github.com/huggingface/transformers/blob/main/tests/utils/test_cache_utils.py#L593C1-L622C60
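For reference, the prefix re-use pattern discussed above looks roughly like this (a sketch loosely based on the linked recipe; the checkpoint name and prompts are placeholders):

```python
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Run the shared prefix once and keep the resulting KV cache.
prefix = "You are a helpful assistant. Answer concisely.\n"
prefix_inputs = tokenizer(prefix, return_tensors="pt").to(model.device)
prompt_cache = DynamicCache()
with torch.no_grad():
    prompt_cache = model(**prefix_inputs, past_key_values=prompt_cache).past_key_values

# Re-use a *copy* of the prefix cache for each new prompt so the original
# cache is not mutated by generation.
for question in ["What is KV caching?", "What does torch.export do?"]:
    inputs = tokenizer(prefix + question, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, past_key_values=copy.deepcopy(prompt_cache), max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))
```

The copy (or a cache clone, as suggested above) is what lets the shared prefix be reused across prompts without recomputation.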
Thank you for validating the changes in Optimum (huggingface/optimum-executorch#4), especially considering the disruption caused by the migration to a new repository.
LGTM, thank you for double-checking the changes in many places @IlyasMoutawwakil 🤗
Let's go if this test ran locally! 🚀
Hey, I work on the torch.export team at PyTorch and I just wanted to verify that this change only works for the static KV cache, right?
Hey! Not sure I understand, the change applies to
Just updated the branch, will merge tonight after tests pass to be sure everything is good.
@IlyasMoutawwakil @ArthurZucker @gante @guangy10 Detaching Cache from torch.nn.Module removed the .float() method from it, which causes an error when calling convert_to_fp32.
self.float() should only be called if self is a tensor or a module; can you explain why it's called in this case?
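To illustrate the point, a minimal sketch of that kind of guard (not the actual accelerate/transformers code; the helper name is hypothetical):

```python
import torch
import torch.nn as nn

def maybe_float(obj):
    # Only tensors and modules implement .float(); anything else
    # (e.g. a cache object that is no longer an nn.Module) is left untouched.
    if isinstance(obj, (torch.Tensor, nn.Module)):
        return obj.float()
    return obj
```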
I saw the same stack trace in another issue, but the user didn't share a script to reproduce it. @eljandoubi could you kindly share a script to reproduce the issue? 🙏
@IlyasMoutawwakil the
@gante I'm afraid I can't, but I can provide guidance. When fine-tuning PaliGemma 2 mix in mixed precision (bf16), the evaluation in the training loop of the HF
This sounds like it was a silent error, because calling [...]. In the case of fine-tuning, I believe you need to pass [...]
@IlyasMoutawwakil Thanks for enlightening me. I think set
(see #37044)
What does this PR do?
In #35792 I came to the conclusion that tensor subclassing is a process that can currently only be achieved with some restrictions (e.g. get_seq_length() needs to return a tensor), adding more developer cognitive load when adding a cache class. In this PR we make the cache traceable and exportable by not making it a Module and instead registering the cache tensors as buffers directly in TorchExportableModuleWithStaticCache.
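A rough sketch of that pattern (simplified, not the exact TorchExportableModuleWithStaticCache implementation; the wrapper name and forward signature here are illustrative):

```python
import torch
import torch.nn as nn

class ExportWrapper(nn.Module):
    # Wraps a model together with a StaticCache and exposes the cache tensors
    # as non-persistent buffers, so torch.export can trace them even though
    # the cache itself is no longer an nn.Module.
    def __init__(self, model, static_cache):
        super().__init__()
        self.model = model
        self.static_cache = static_cache
        for i in range(len(self.static_cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

    def forward(self, input_ids, cache_position):
        outputs = self.model(
            input_ids,
            past_key_values=self.static_cache,
            cache_position=cache_position,
            use_cache=True,
        )
        return outputs.logits
```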
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.