Enable Gradient Accumulation fix across all models + trainer fully in forward() #34283
Changes from 8 commits. Commits in the PR: 0bfce6e, 8c12bf4, 058fe34, fc6d674, 0aeb5ac, 49b29d2, 4f3f86d, 58ee680, 2d58b30, 921abb8, 4417984, 21ca9a4, 9967afc, 98cbf7c, 4e2328d.

The diff hunks below show the `Trainer` changes in this commit range.
```diff
@@ -582,6 +582,16 @@ def __init__(
         self.model_wrapped = model
         self.model = model

+        # Just in case the model was wrapped outside of the `Trainer`
+        unwrapped_model = self.accelerator.unwrap_model(model)
+        model_forward = (
+            unwrapped_model.forward
+            if not _is_peft_model(unwrapped_model)
+            else unwrapped_model.get_base_model().forward
+        )
+
+        self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters
+
         self.neftune_noise_alpha = args.neftune_noise_alpha

         self.compute_metrics = compute_metrics
```
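The `model_accepts_loss_kwargs` flag is derived purely from signature inspection of the unwrapped model's `forward`. A minimal sketch of the same idea (the `DummyModel` and `LegacyModel` classes below are made up for illustration, not taken from the PR):

```python
import inspect


class DummyModel:
    """Stand-in for a model whose forward accepts extra loss kwargs."""

    def forward(self, input_ids, labels=None, **loss_kwargs):
        ...


class LegacyModel:
    """Stand-in for a model with a fixed forward signature."""

    def forward(self, input_ids, labels=None):
        ...


def accepts_loss_kwargs(model) -> bool:
    # Mirrors the check in the diff above: a `**loss_kwargs` catch-all
    # shows up as a parameter named "loss_kwargs" in the signature.
    return "loss_kwargs" in inspect.signature(model.forward).parameters


print(accepts_loss_kwargs(DummyModel()))   # True
print(accepts_loss_kwargs(LegacyModel()))  # False
```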
```diff
@@ -2455,7 +2465,7 @@ def _inner_training_loop(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                 with self.accelerator.accumulate(model):
-                    tr_loss_step = self.training_step(model, inputs)
+                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

                 if (
                     args.logging_nan_inf_filter
```
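`num_items_in_batch` is computed outside of this hunk; the diff only shows it being threaded into `training_step`. A rough sketch of how such a count can be obtained for causal-LM style batches (the helper name and the `-100` ignore index here are assumptions for illustration, not taken from this diff):

```python
import torch


def count_items_in_batch(batch_samples, ignore_index: int = -100) -> int:
    """Count the label tokens that actually contribute to the loss
    across all micro-batches of one gradient-accumulation window."""
    total = 0
    for batch in batch_samples:
        labels = batch.get("labels")
        if labels is not None:
            total += int((labels != ignore_index).sum())
    return total


# Example: two micro-batches, some positions masked out with -100.
micro_batches = [
    {"labels": torch.tensor([[1, 2, -100], [3, -100, -100]])},
    {"labels": torch.tensor([[4, 5, 6]])},
]
print(count_items_in_batch(micro_batches))  # 6
```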
```diff
@@ -2477,6 +2487,24 @@ def _inner_training_loop(
                     steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
                 )

+                performing_optimizer_step = (
+                    is_last_step_and_steps_less_than_grad_acc
+                    or (total_batched_samples) % args.gradient_accumulation_steps == 0
+                )
+
+                # During DDP we always multiply gradients by data_parallel_size / sample size since
+                # DDP normalizes by the number of data parallel workers
+                numerator = (
+                    self.accelerator.state.num_processes
+                    if not performing_optimizer_step
+                    and self.accelerator.state.distributed_type == DistributedType.MULTI_GPU
+                    else 1
+                )
+
+                # Only valid in accelerate >= 1.1.0
+                if hasattr(self.optimizer, "multiply_grads"):
+                    self.optimizer.multiply_grads(numerator / (num_items_in_batch or 1.0))
+
                 if (
                     (total_batched_samples) % args.gradient_accumulation_steps == 0
                     or
```
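For context on why normalizing by `num_items_in_batch` matters at all: averaging the loss inside each micro-batch and then averaging across micro-batches is not the same as averaging over all tokens in the accumulation window when the micro-batches contain different numbers of valid tokens. A small numeric illustration (the values are made up):

```python
# Two micro-batches with different numbers of valid (non-padded) tokens.
losses_a = [2.0, 4.0]            # 2 tokens
losses_b = [1.0, 1.0, 1.0, 1.0]  # 4 tokens

# Naive gradient accumulation: mean per micro-batch, then mean of means.
per_batch_means = [sum(losses_a) / len(losses_a), sum(losses_b) / len(losses_b)]
naive = sum(per_batch_means) / len(per_batch_means)

# What a single large batch would give: mean over all tokens.
all_losses = losses_a + losses_b
exact = sum(all_losses) / len(all_losses)

print(naive)  # 2.0
print(exact)  # ~1.67, so the two disagree; hence the num_items_in_batch fix
```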
```diff
@@ -3595,7 +3623,6 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
             self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps
```
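The removed `loss *= self.args.gradient_accumulation_steps` rescaling is effectively superseded by the gradient-side `multiply_grads` call in the hunk above; both approaches rely on gradients scaling linearly with the loss. A tiny PyTorch check of that equivalence (this snippet is an illustration, not code from the PR):

```python
import torch

# Gradients scale linearly with the loss, so multiplying the loss before
# backward() and multiplying the gradients after backward() are equivalent.
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(2.0, requires_grad=True)
scale = 4.0  # e.g. a gradient_accumulation_steps-style factor

(scale * w1 * w1).backward()  # scale the loss up front
(w2 * w2).backward()          # ...or scale the gradient afterwards
w2.grad *= scale

print(w1.grad, w2.grad)  # tensor(16.) tensor(16.)
```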
```diff
@@ -3610,8 +3637,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             labels = inputs.pop("labels")
         else:
             labels = None
-        # if num_items_in_batch is not None:
-        #     inputs["num_items_in_batch"] = num_items_in_batch
+        if self.model_accepts_loss_kwargs:
+            loss_kwargs = {}
+            if num_items_in_batch is not None:
+                loss_kwargs["num_items_in_batch"] = num_items_in_batch
+            inputs = {**inputs, **loss_kwargs}
         outputs = model(**inputs)
         # Save past state if it exists
         # TODO: this needs to be fixed and made cleaner later.
```

Review thread on the `if self.model_accepts_loss_kwargs:` line:

- This if condition doesn't seem to work for … I tried just changing that condition to … I'm wondering if there's a safe/non-breaking way to support peft models here as well?
- Thanks for the flag, that may be why I couldn't recreate #34263 (comment)
- @man-shar can you try giving peft another go? Should have fixed it
- Can confirm … But seems like the … I notice that argument was removed in 4f3f86d and the commit still says "Experimental". So I assume it will be reverted once you're done experimenting! It should work after that!
- Yes indeed :)
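The thread above is about whether `num_items_in_batch` actually reaches the model's forward. For context, here is a hedged sketch of what a loss that consumes it could look like: summed over tokens and divided by the total count for the whole accumulation window instead of averaged per micro-batch. This is an illustration only, not the actual modeling code in transformers:

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits, labels, num_items_in_batch=None, ignore_index=-100):
    """Sketch of a loss that can normalize over a whole accumulation window."""
    # Shift so that each position predicts the next token.
    logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
    labels = labels[:, 1:].contiguous().view(-1)

    if num_items_in_batch is not None:
        # Sum over tokens, then divide by the global token count supplied
        # by the trainer, so the result is consistent under accumulation.
        loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="sum")
        return loss / num_items_in_batch
    # Fallback: plain per-micro-batch mean.
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)


# Example usage with random logits.
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(causal_lm_loss(logits, labels, num_items_in_batch=8))
```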
A second review thread discusses the `loss_kwargs` design:

- I think making it an argument `loss_args` rather than kwargs would be better?
- They are explicitly `kwargs`, not positional, so kwargs is more accurate
- I wasn't referring to the variable name. My suggestion is to use `loss_kwargs` instead of `**loss_kwargs`. The current design prevents forward functions from accepting additional keyword arguments, which could be inconvenient.
- We haven't accepted that at all before
- This is prone to evolve either way, with #33932 as an example!
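For readers following that design debate, the difference is between unpacking the dict into the model call and passing it as a single named argument. A quick illustration of the two call styles (the function names below are hypothetical):

```python
def forward_unpacked(input_ids, labels=None, **loss_kwargs):
    # Style used in the diff: extra loss-related values arrive as
    # individual keyword arguments, e.g. num_items_in_batch=...
    return loss_kwargs.get("num_items_in_batch")


def forward_dict(input_ids, labels=None, loss_kwargs=None):
    # Alternative raised in review: the whole dict is one named argument,
    # leaving any **kwargs free for other keyword arguments.
    loss_kwargs = loss_kwargs or {}
    return loss_kwargs.get("num_items_in_batch")


payload = {"num_items_in_batch": 42}
print(forward_unpacked([1, 2], **payload))        # 42
print(forward_dict([1, 2], loss_kwargs=payload))  # 42
```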