Fix bugs in DynamicCache #37880
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the "Ready for review" button.
cc @gante
Thank you for the PR 🤗 In general, LGTM (I'm not super keen on increasing the complexity in `DynamicCache`, but I understand the importance of the fix).
Missing: update the docstring with the new optional arg.
src/transformers/cache_utils.py (outdated)
@@ -359,11 +359,15 @@ class DynamicCache(Cache):
-    def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None:
+    def __init__(self, _distributed_cache_data: Optional[Iterable] = None, num_layers: Optional[int] = None) -> None:
Let's accept `config` instead of `num_layers` (= `config.num_layers`). It's more consistent with the other caches, which also take `config` in `__init__`.
Sure!
thanks! Not sure we need a new argument here!
src/transformers/cache_utils.py (outdated)
        self.key_cache = [torch.tensor([]) for _ in range(num_layers)]
        self.value_cache = [torch.tensor([]) for _ in range(num_layers)]
why don't we always init like this?
We need to know how many layers we want to do this for.
`DynamicCache` has lazy tensor init, and export needs eager tensor init :D It's similar to the issue we have with TP (should be lazy) vs torch.compile (should be eager) in the hybrid caches.
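To make the lazy-vs-eager distinction concrete, here is a minimal, illustrative sketch (not the actual `cache_utils.py` code; the class name and the `update()` shape handling are assumptions):

```python
from typing import Optional

import torch


class LazyVsEagerCacheSketch:
    """Illustrative only: contrasts lazy and eager per-layer tensor initialization."""

    def __init__(self, num_layers: Optional[int] = None) -> None:
        if num_layers is None:
            # Lazy: no per-layer entries yet; layers get appended on the first update().
            self.key_cache: list[torch.Tensor] = []
            self.value_cache: list[torch.Tensor] = []
        else:
            # Eager: the number of layers (and hence the pytree structure) is fixed
            # at construction time, which is what torch.export needs for a stable spec.
            self.key_cache = [torch.tensor([]) for _ in range(num_layers)]
            self.value_cache = [torch.tensor([]) for _ in range(num_layers)]

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        if len(self.key_cache) <= layer_idx:
            # Lazy path: grow the lists as layers are first seen.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif self.key_cache[layer_idx].numel() == 0:
            # Eager path: replace the empty placeholder for this layer.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Subsequent steps: concatenate along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```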
One more detail and it's good for me 👍
src/transformers/cache_utils.py (outdated)
@@ -359,11 +359,17 @@ class DynamicCache(Cache):
-    def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None:
+    def __init__(
+        self, _distributed_cache_data: Optional[Iterable] = None, config: Optional[PretrainedConfig] = None
Missing: docs for `config` in the docstring above, explaining when it should be used (torch.export). (Sorry, I missed this detail in the previous review :D)
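For illustration, the requested docstring entry could look something like the sketch below; the exact wording and the reliance on `config.num_hidden_layers` are assumptions, not the merged text:

```python
class DynamicCache:  # illustrative stub; the real class lives in src/transformers/cache_utils.py
    """
    Args:
        _distributed_cache_data (`Iterable`, *optional*):
            (existing description, unchanged)
        config (`PretrainedConfig`, *optional*):
            If provided, the key/value caches are eagerly initialized with one empty
            tensor per layer (based on `config.num_hidden_layers`) instead of lazily
            on the first `update()` call. This is mainly intended for `torch.export`,
            which needs the cache structure to be known at trace time.
    """
```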
Force-pushed from d5ee85f to fabfd80.
I originally hoped to make DynamicCache torch.export compatible with dynamic shapes, but this seems quite difficult and outside the scope of export, since the caching code is not really part of the model's forward pass. To make it work,

Both of the above would make the transformers code quite ugly. And on the export side, we are working on exporting submodules with different input specs, so I don't feel it is that important to make dynamic shapes fully seamless with export at the cost of code complexity. Our current suggestion would be to get two graphs:

This PR still fixes the bug where we weren't able to run the exported artifact when dynamic shapes are used. cc @xadupre @zhxchen17
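As a rough illustration of the workflow being described (prefill eagerly, then export and run the decode step with an already-populated cache), here is a hedged sketch; the model name is the tiny test checkpoint used in this PR's test, but the exact export arguments and shapes are assumptions:

```python
import copy

import torch
from transformers import AutoModelForCausalLM, DynamicCache

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model.eval()

# Prefill eagerly so the cache actually contains per-layer tensors before export.
prompt_ids = torch.randint(0, model.config.vocab_size, (1, 8))
cache = DynamicCache()
with torch.no_grad():
    model(prompt_ids, past_key_values=cache, use_cache=True)

# Export the decode step, passing the populated cache as an input.
next_token = torch.randint(0, model.config.vocab_size, (1, 1))
ep = torch.export.export(
    model,
    args=(next_token,),
    kwargs={"past_key_values": cache, "use_cache": True},
    strict=False,
)

# The bug fixed here is about being able to run the exported artifact afterwards;
# run it on a fresh copy of the populated cache.
out = ep.module()(next_token, past_key_values=copy.deepcopy(cache), use_cache=True)
```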
Hey! I think we can probably make the test a bit cleaner, then let's go! 🤗🚀
tests/utils/test_cache_utils.py (outdated)
    def test_dynamic_cache_exportability_dynamic_cache(self):
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
Is it an extension of `test_dynamic_cache_exportability`, or a new test that should be independent? If an extension, let's simply add the new parts to the existing test; otherwise, let's have a better name for this new test! 🤗
Done!
Hey @tugsbayasgalan! The new test you added does not pass (see the CI report below the PR), so it would need to be fixed before merging!
tests/utils/test_cache_utils.py (outdated)
    @slow
    @require_read_token
It shouldn't need these decorators, should it?
Nah, it was just copy-pasta. Deleted.
We just need to fix the small conflict based on our new ruff rules, then it's good to go!
Force-pushed from 0d89988 to 54dc95a.
What does this PR do?
When we flatten `DynamicCache` for export, we never end up flattening its inner tensors, because at the time we start there are zero tensors initialized. As a result, we weren't correctly testing the `ep.module()(*args, **kwargs)` behaviour when exporting with a populated cache.
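To make the failure mode concrete, here is a minimal sketch using `torch.utils._pytree` directly (it assumes `DynamicCache` is registered as a pytree node, which is what the export integration relies on):

```python
import torch
import torch.utils._pytree as pytree
from transformers import DynamicCache

# A freshly created DynamicCache holds no tensors yet (lazy init) ...
empty_cache = DynamicCache()
leaves, _ = pytree.tree_flatten(empty_cache)
print(sum(isinstance(x, torch.Tensor) for x in leaves))  # expected: 0 tensor leaves

# ... whereas a populated cache flattens into its per-layer key/value tensors.
populated_cache = DynamicCache()
populated_cache.update(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8), layer_idx=0)
leaves, _ = pytree.tree_flatten(populated_cache)
print(sum(isinstance(x, torch.Tensor) for x in leaves))  # expected: 2 tensor leaves (key + value)
```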
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.