
Commit d9f7336

Enable Gradient Accumulation fix across all models + trainer fully in forward() (#34283)
* Enable grad accum fix across all models + trainer fully in forward()
* handle peft case
* Account for DDP: need to run scale tests
* Use accelerator state
* Quality
* Guard
* Experiment w/ only fairseq fix
* Fairseq only
* Revert multiply_grads fix
* Mult by grad accum to fully bring back solution
* Style
* Good to go now
* Skip fx tests for now
* Bookmark
* Working now
1 parent 1fb575f commit d9f7336

25 files changed: 81 additions, 31 deletions
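
The model-side change is the same everywhere: `forward()` now accepts `**loss_kwargs` and forwards them to `self.loss_function`, so the trainer can pass batch-level information (notably the number of label tokens in the full gradient-accumulation window) all the way into the loss. Without it, each micro-batch contributes a per-micro-batch mean loss, and averaging those means weights tokens unevenly whenever micro-batches contain different numbers of non-padded tokens. A minimal sketch of the idea, not the exact Transformers implementation (the kwarg name `num_items_in_batch` is used here for illustration):

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
    # Shift so tokens < n predict token n, as in standard causal LM training.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    if num_items_in_batch is None:
        # Old behaviour: per-micro-batch mean; averaging these across accumulation
        # steps over-weights tokens from smaller micro-batches.
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)
    # Fixed behaviour: sum the token losses and divide once by the number of label
    # tokens in the whole accumulated batch, supplied by the trainer via **loss_kwargs.
    summed = F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index, reduction="sum")
    return summed / num_items_in_batch
```

With the loss normalized once by the true token count, the per-file diffs below are purely mechanical.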

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 2 additions & 1 deletion
@@ -1114,6 +1114,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1172,7 +1173,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
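
On the caller side, anything passed to the model that the signature does not name explicitly lands in `**loss_kwargs` and reaches `self.loss_function`. A hedged sketch of how a training loop could use that (names such as `micro_batches` and the `num_items_in_batch` kwarg are assumptions for illustration, not the Trainer's exact wiring):

```python
def accumulation_step(model, optimizer, micro_batches, ignore_index=-100):
    # Count label tokens across the whole accumulation window (shifted by one,
    # matching the causal LM loss).
    num_items_in_batch = sum(
        int((mb["labels"][..., 1:] != ignore_index).sum()) for mb in micro_batches
    )
    for mb in micro_batches:
        # The extra kwarg is not in the explicit signature, so it falls into
        # **loss_kwargs and is forwarded to self.loss_function.
        out = model(**mb, num_items_in_batch=num_items_in_batch)
        out.loss.backward()  # per-micro-batch losses now add up to the window mean
    optimizer.step()
    optimizer.zero_grad()
```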

src/transformers/models/gemma/modeling_gemma.py

Lines changed: 2 additions & 1 deletion
@@ -1030,6 +1030,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1087,7 +1088,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

src/transformers/models/gemma/modular_gemma.py

Lines changed: 2 additions & 1 deletion
@@ -961,6 +961,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         ```python
@@ -1003,7 +1004,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 2 additions & 1 deletion
@@ -1002,6 +1002,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1068,7 +1069,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

src/transformers/models/gemma2/modular_gemma2.py

Lines changed: 2 additions & 1 deletion
@@ -756,6 +756,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         ```python
@@ -807,7 +808,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

src/transformers/models/glm/modeling_glm.py

Lines changed: 2 additions & 1 deletion
@@ -1014,6 +1014,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1071,7 +1072,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

src/transformers/models/jamba/modeling_jamba.py

Lines changed: 2 additions & 1 deletion
@@ -1450,6 +1450,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: Optional[Union[int, None]] = None,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
@@ -1515,7 +1516,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 2 additions & 1 deletion
@@ -1240,6 +1240,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
@@ -1303,7 +1304,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         aux_loss = None
         if output_router_logits:

src/transformers/models/mllama/modeling_mllama.py

Lines changed: 2 additions & 1 deletion
@@ -1887,6 +1887,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1949,7 +1950,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]

src/transformers/models/nemotron/modeling_nemotron.py

Lines changed: 2 additions & 1 deletion
@@ -1028,6 +1028,7 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
+        **loss_kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1085,7 +1086,7 @@ def forward(

         loss = None
         if labels is not None:
-            loss = self.loss_function(logits, labels, self.vocab_size)
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

         if not return_dict:
             output = (logits,) + outputs[1:]
