Enable Gradient Accumulation fix across all models + trainer fully in forward() #34283
Conversation
Nice, I don't see Llama being modified, that's probably because it now has a FlashAttentionKwargs TypedDict as kwargs. We can create ExtraKwargs, a nested dict with both the flash kwargs and the loss kwargs, and the default loss kwargs can be a TypedDict?
🤗
@ArthurZucker
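For reference, one possible reading of that idea as a sketch (the names ExtraKwargs/LossKwargs and the flash-attention field names are assumptions for illustration, not code in this PR):

```python
from typing import TypedDict

import torch


class LossKwargs(TypedDict, total=False):
    # Extra information the loss needs, e.g. the token count across the
    # gradient-accumulation window.
    num_items_in_batch: int


class FlashAttentionKwargs(TypedDict, total=False):
    # Padding-free / varlen flash-attention metadata (field names assumed).
    cu_seq_lens_q: torch.LongTensor
    cu_seq_lens_k: torch.LongTensor
    max_length_q: int
    max_length_k: int


class ExtraKwargs(TypedDict, total=False):
    # Nested variant: both groups of kwargs under one typed container.
    flash_kwargs: FlashAttentionKwargs
    loss_kwargs: LossKwargs
```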
@@ -3610,8 +3612,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
            labels = inputs.pop("labels")
        else:
            labels = None
        # if num_items_in_batch is not None:
        #     inputs["num_items_in_batch"] = num_items_in_batch
        if self.model_accepts_loss_kwargs:
This if condition doesn't seem to work for the PeftModel class (it only has kwargs, not loss_kwargs 🫠). I tried just changing that condition to if True and ran some tests, and the loss calculation worked perfectly for a LoRA on a Llama 3 1B.
I'm wondering if there's a safe, non-breaking way to support PEFT models here as well?
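A minimal sketch of the kind of workaround being hinted at here, assuming it is acceptable to unwrap the PEFT wrapper before inspecting the signature (this is an illustration, not the PR's actual fix):

```python
import inspect


def accepts_loss_kwargs(model) -> bool:
    # PeftModel.forward only exposes a generic **kwargs, so inspect the wrapped
    # base model instead when the PEFT accessor is available.
    base_model = model.get_base_model() if hasattr(model, "get_base_model") else model
    return "loss_kwargs" in inspect.signature(base_model.forward).parameters
```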
Thanks for the flag, that may be why I couldn't recreate #34263 (comment)
@man-shar can you try giving peft another go? Should have fixed it
Can confirm self.model_accepts_loss_kwargs is the correct value now for a PEFT model! But it seems like the compute_loss function inside trainer.py isn't getting the num_items_in_batch argument passed to it from training_step.
I notice that argument was removed in 4f3f86d and the commit still says "Experimental", so I assume it will be reverted once you're done experimenting! It should work after that!
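For context, a rough sketch of how num_items_in_batch can be computed over an accumulation window before being handed to compute_loss (assuming -100 is the ignored label index; illustrative only, not the exact trainer.py code):

```python
def count_items_in_batch(batch_samples, ignore_index: int = -100) -> int:
    # Total number of label tokens that actually contribute to the loss across
    # all micro-batches in the gradient-accumulation window.
    return sum(int((batch["labels"] != ignore_index).sum()) for batch in batch_samples)
```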
Yes indeed :)
Awesome work by all of you on this. Insane dev speed over the past few days 🙏 🔥
I only reviewed the PEFT-related code in trainer.py and it LGTM. Thanks Zach.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Rebasing should probably fix quality!
Force-pushed from c2a8453 to fc6d674
@@ -1114,6 +1114,7 @@ def forward(
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
I think making it an argument loss_args rather than kwargs would be better?
They are explicitly kwargs, not positional, so kwargs is more accurate.
I wasn’t referring to the variable name. My suggestion is to use loss_kwargs instead of **loss_kwargs. The current design prevents forward functions from accepting additional keyword arguments, which could be inconvenient.
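To make the two options concrete, a toy comparison (signatures are illustrative only, not the real model code):

```python
from typing import Optional

import torch


# Option A (what the diff above does): collect loss-related keyword arguments
# directly into **loss_kwargs, which claims the variadic slot.
def forward_a(input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, **loss_kwargs):
    ...


# Option B (the suggestion here): take an explicit dict, leaving **kwargs free
# for other future extensions.
def forward_b(
    input_ids: torch.LongTensor,
    labels: Optional[torch.LongTensor] = None,
    loss_kwargs: Optional[dict] = None,
    **kwargs,
):
    ...
```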
We haven't accepted that at all before
This is prone to evolve either way, with #33932 as an example!
I wonder whether encoder models need to be fixed as well. @ArthurZucker @muellerzr
Yep! Feel free to open a PR 🤗
… forward() (huggingface#34283)

* Enable grad accum fix across all models + trainer fully in forward()
* handle peft case
* Account for DDP: need to run scale tests
* Use accelerator state
* Quality
* Guard
* Experiment w/ only fairseq fix
* Fairseq only
* Revert multiply_grads fix
* Mult by grad accum to fully bring back solution
* Style
* Good to go now
* Skip fx tests for now
* Bookmark
* Working now
What does this PR do?
Since most users still want this to work out of the box (OOTB), this PR trickles the loss kwargs down to the rest of the models so that the causal loss can be calculated properly (see the sketch below the issue links).
Fully fixes #34263 / finishes #34191 & #34198 & fixes #34242
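As a rough sketch of what passing num_items_in_batch into the model enables (illustrative only, close in spirit to but not a copy of the library's loss code):

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits, labels, num_items_in_batch=None, ignore_index: int = -100):
    # Shift so that each position predicts the next token.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Sum instead of mean when the global token count is known, then divide by
    # the count across the whole gradient-accumulation window so each token
    # contributes equally regardless of how the batch was split into micro-batches.
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
        reduction=reduction,
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss
```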
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker