
Enable Gradient Accumulation fix across all models + trainer fully in forward() #34283

Merged: 15 commits, Oct 23, 2024
3 changes: 2 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -1114,6 +1114,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,

Reviewer: I think making it an argument `loss_args` rather than kwargs would be better?

Contributor (Author): They are explicitly kwargs, not positional, so `kwargs` is more accurate.

Reviewer: I wasn't referring to the variable name. My suggestion is to use `loss_kwargs` instead of `**loss_kwargs`. The current design prevents forward functions from accepting additional keyword arguments, which could be inconvenient.

Contributor (Author): We haven't accepted that at all before.

Collaborator: This is prone to evolve either way, with #33932 as an example!
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1172,7 +1173,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
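The per-model edit above is mechanical: `forward()` gains a `**loss_kwargs` catch-all and forwards it verbatim to `self.loss_function`, so whatever the Trainer passes (in practice `num_items_in_batch`) reaches the loss. A tensor-free toy sketch of the pattern — the loss function below is an illustrative stand-in, not the transformers implementation; it takes per-token losses that are assumed to be already computed:

```python
def toy_loss_function(per_token_losses, num_items_in_batch=None, **kwargs):
    # Sum the per-token losses, then normalize by the *global* token count when
    # the trainer supplies it; otherwise fall back to a per-batch mean.
    total = sum(per_token_losses)
    denom = num_items_in_batch if num_items_in_batch is not None else len(per_token_losses)
    return total / denom

def forward(per_token_losses, labels=None, **loss_kwargs):
    # Mirrors the diff: the catch-all is forwarded verbatim to the loss function.
    loss = None
    if labels is not None:
        loss = toy_loss_function(per_token_losses, **loss_kwargs)
    return loss
```

With six per-token losses, normalizing by a global count of 12 (spanning several micro-batches) halves the contribution of this micro-batch, which is exactly the behavior gradient accumulation needs.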
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -1026,6 +1026,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1083,7 +1084,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/modular_gemma.py
@@ -960,6 +960,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     ```python
@@ -1002,7 +1003,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -998,6 +998,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1064,7 +1065,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modular_gemma2.py
@@ -754,6 +754,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     ```python
@@ -805,7 +806,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/glm/modeling_glm.py
@@ -1014,6 +1014,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1071,7 +1072,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/jamba/modeling_jamba.py
@@ -1450,6 +1450,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: Optional[Union[int, None]] = None,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -1515,7 +1516,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:
3 changes: 2 additions & 1 deletion src/transformers/models/mixtral/modeling_mixtral.py
@@ -1238,6 +1238,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -1301,7 +1302,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:
3 changes: 2 additions & 1 deletion src/transformers/models/mllama/modeling_mllama.py
@@ -1887,6 +1887,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1949,7 +1950,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/nemotron/modeling_nemotron.py
@@ -1026,6 +1026,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1083,7 +1084,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/olmo/modeling_olmo.py
@@ -1068,6 +1068,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1126,7 +1127,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/olmoe/modeling_olmoe.py
@@ -1228,6 +1228,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -1290,7 +1291,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:
3 changes: 2 additions & 1 deletion src/transformers/models/phi/modeling_phi.py
@@ -1192,6 +1192,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1250,7 +1251,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/phi3/modeling_phi3.py
@@ -1209,6 +1209,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1275,7 +1276,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/phimoe/modeling_phimoe.py
@@ -1377,6 +1377,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -1442,7 +1443,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:
3 changes: 2 additions & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
@@ -1120,6 +1120,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1178,7 +1179,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
3 changes: 2 additions & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -1304,6 +1304,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
     Args:
@@ -1366,7 +1367,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:
2 changes: 2 additions & 0 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
@@ -2027,6 +2027,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **loss_kwargs,
     ) -> Union[Tuple[torch.FloatTensor], RTDetrObjectDetectionOutput]:
     r"""
     labels (`List[Dict]` of len `(batch_size,)`, *optional*):
@@ -2128,6 +2129,7 @@ def forward(
                 enc_topk_logits=enc_topk_logits,
                 enc_topk_bboxes=enc_topk_bboxes,
                 denoising_meta_values=denoising_meta_values,
+                **loss_kwargs,
             )

         if not return_dict:
3 changes: 2 additions & 1 deletion src/transformers/models/zamba/modeling_zamba.py
@@ -1418,6 +1418,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
     Args:
@@ -1477,7 +1478,7 @@ def forward(
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
38 changes: 34 additions & 4 deletions src/transformers/trainer.py
@@ -582,6 +582,16 @@ def __init__(
         self.model_wrapped = model
         self.model = model

+        # Just in case the model was wrapped outside of the `Trainer`
+        unwrapped_model = self.accelerator.unwrap_model(model)
+        model_forward = (
+            unwrapped_model.forward
+            if not _is_peft_model(unwrapped_model)
+            else unwrapped_model.get_base_model().forward
+        )
+
+        self.model_accepts_loss_kwargs = "loss_kwargs" in inspect.signature(model_forward).parameters

self.neftune_noise_alpha = args.neftune_noise_alpha

self.compute_metrics = compute_metrics
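The `model_accepts_loss_kwargs` probe above works because a `**loss_kwargs` catch-all appears in `inspect.signature(...).parameters` under its own name. A small standalone sketch of the check — the stub `forward` functions are illustrative, not transformers code:

```python
import inspect

def accepts_loss_kwargs(model_forward) -> bool:
    # A `**loss_kwargs` catch-all shows up in the signature as a VAR_KEYWORD
    # parameter *named* loss_kwargs; the probe matches the name, nothing else.
    return "loss_kwargs" in inspect.signature(model_forward).parameters

# Illustrative stubs:
def new_style_forward(input_ids, labels=None, **loss_kwargs): ...
def old_style_forward(input_ids, labels=None): ...
def wrapper_forward(*args, **kwargs): ...  # e.g. a PeftModel-style wrapper
```

Because only the name is matched, a wrapper exposing a generic `**kwargs` fails the probe, which is why the `__init__` snippet above unwraps PEFT models via `get_base_model().forward` before inspecting.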
@@ -2455,7 +2465,7 @@ def _inner_training_loop(
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                 with self.accelerator.accumulate(model):
-                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+                    tr_loss_step = self.training_step(model, inputs)

if (
args.logging_nan_inf_filter
@@ -2477,6 +2487,24 @@
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)

+                    performing_optimizer_step = (
+                        is_last_step_and_steps_less_than_grad_acc
+                        or (total_batched_samples) % args.gradient_accumulation_steps == 0
+                    )
+
+                    # During DDP we always multiply gradients by data_parallel_size / sample size since
+                    # DDP normalizes by the number of data parallel workers
+                    numerator = (
+                        self.accelerator.state.num_processes
+                        if not performing_optimizer_step
+                        and self.accelerator.state.distributed_type == DistributedType.MULTI_GPU
+                        else 1
+                    )
+
+                    # Only valid in accelerate >= 1.1.0
+                    if hasattr(self.optimizer, "multiply_grads"):
+                        self.optimizer.multiply_grads(numerator / (num_items_in_batch or 1.0))
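The normalization above matters because, when micro-batches contain different numbers of label tokens, averaging per-micro-batch means is not the same as the true per-token mean over the whole accumulated batch. A framework-free toy illustration of the discrepancy:

```python
# Two gradient-accumulation micro-batches with different label-token counts.
per_token_losses = [[1.0, 3.0], [4.0] * 8]

# Naive accumulation: average within each micro-batch, then average the averages.
mean_of_means = sum(sum(b) / len(b) for b in per_token_losses) / len(per_token_losses)

# Fixed accumulation: sum every token loss and divide once by the global count
# (this is the role num_items_in_batch plays in the trainer).
num_items_in_batch = sum(len(b) for b in per_token_losses)
global_mean = sum(sum(b) for b in per_token_losses) / num_items_in_batch
```

Here `mean_of_means` is 3.0 while the true global per-token mean is 3.6: the naive scheme over-weights tokens from short micro-batches, which is exactly the gradient-accumulation bug this PR addresses.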

if (
(total_batched_samples) % args.gradient_accumulation_steps == 0
or
@@ -3595,7 +3623,6 @@ def training_step(
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
-            loss *= self.args.gradient_accumulation_steps
             self.accelerator.backward(loss, **kwargs)

         return loss.detach() / self.args.gradient_accumulation_steps
@@ -3610,8 +3637,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             labels = inputs.pop("labels")
         else:
             labels = None
-        # if num_items_in_batch is not None:
-        #     inputs["num_items_in_batch"] = num_items_in_batch
+        if self.model_accepts_loss_kwargs:

Reviewer: This `if` condition doesn't seem to work for the PeftModel class (it only has `kwargs`, not `loss_kwargs` 🫠). I tried changing the condition to `if True` and ran some tests, and the loss calculation worked perfectly for a LoRA on Llama 3 1B. I'm wondering if there's a safe, non-breaking way to support peft models here as well?

Contributor (Author): Thanks for the flag, that may be why I couldn't recreate #34263 (comment).

Contributor (Author): @man-shar can you try giving peft another go? Should have fixed it.

Reviewer: Can confirm `self.model_accepts_loss_kwargs` is now the correct value for a peft model! But it seems the `compute_loss` function inside trainer.py isn't getting the `num_items_in_batch` argument passed to it from `training_step`. I notice that argument was removed in 4f3f86d and the commit still says "Experimental", so I assume it will be reverted once you're done experimenting. It should work after that!

Contributor (Author): Yes indeed :)

+            loss_kwargs = {}
+            if num_items_in_batch is not None:
+                loss_kwargs["num_items_in_batch"] = num_items_in_batch
+            inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
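The new `compute_loss` flow above can be sketched in isolation: the normalization kwarg is merged into the model inputs only when the probed `forward()` was found to accept `**loss_kwargs`. The function below is a simplified stand-in for the Trainer method, and the dict-echoing `model` in the usage is a stand-in for a real model:

```python
def compute_loss_sketch(model, inputs, model_accepts_loss_kwargs, num_items_in_batch=None):
    # Mirrors the new compute_loss flow: the extra kwarg reaches the model only
    # when its forward() was probed to accept **loss_kwargs.
    if model_accepts_loss_kwargs:
        loss_kwargs = {}
        if num_items_in_batch is not None:
            loss_kwargs["num_items_in_batch"] = num_items_in_batch
        inputs = {**inputs, **loss_kwargs}
    return model(**inputs)
```

Calling it with a model that simply echoes its keyword arguments shows that the probed path forwards `num_items_in_batch` while the legacy path leaves the inputs untouched.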