Fix Gradient Accumulation issue #34191
Merged
Commits (58, all by ArthurZucker):

44301cc  quick fix
1456088  3 losses
e57f00c  oups
7fa8503  fix
b955ea5  nits
07478e0  check how it scales for special models
1b356ef  propagate for conditiona detr
4ef45b0  propagate
61da9b1  propagate
2e3f0f7  propagate
c31a3fb  fixes
a8cd107  propagate changes
711c357  update
4888cf3  fixup
4323d85  nits
e5e4bbd  f string
239a256  fixes
bd298da  more fixes
5dfc51c  ?
0a1cd2b  nit
64f7e29  arg annoying f string
aa01ae9  nits
8c1d68a  grumble
846cf1c  update
e7e8a20  nit
622290c  refactor
91e28aa  fix fetch tests
da649b9  nit
df6472a  nit
cf1eb7b  Update src/transformers/loss/loss_utils.py
dafd11b  Merge branch 'quick-fix-ga' of github.com:huggingface/transformers in…
30f27cd  update
d0edfad  nit
9bcecc3  fixup
2839b3c  make pass
557d225  nits
393e178  port code to more models
aac054d  fixup
ce32d5e  ntis
4dc49ac  arf
d221e58  update
f03b193  update
22b6283  nits
64829e3  update
0b6f425  fix
e6f6f52  update
fa691aa  nits
66f6eef  fine
fcdf13d  agjkfslga.jsdlkgjklas
ece5e01  nits
bb236eb  fix fx?
7c2b7ce  update
0be4379  update
36d76d7  styel
92979e7  fix imports
a55e440  update
b14c3dd  update
dbbc3ce  fixup to fix the torch fx?
src/transformers/loss/loss_utils.py (new file)
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from .models.detr.loss_detr import ForObjectDetectionLoss, ForSegmentationLoss


def DefaultCrossEntropyLoss(logits, labels, **kwargs):
    # Upcast to float to avoid potential precision issues when computing the loss
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, kwargs["vocab_size"])
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    num_items = kwargs.pop("num_items", None)

    if num_items is not None:
        # Compute the cross-entropy manually when using gradient accumulation:
        # sum the per-token losses and divide by the global number of items
        # across all accumulated micro-batches, instead of averaging per batch.
        log_probs = nn.functional.log_softmax(shift_logits, dim=-1)
        # Mask out ignored positions (label -100) so they neither select a
        # spurious class via negative indexing nor contribute to the sum.
        valid = shift_labels != -100
        indices = torch.arange(shift_labels.size(0), device=shift_labels.device)
        loss = -log_probs[indices[valid], shift_labels[valid]]
        loss = loss.sum() / num_items
    else:
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

    return loss


def ForSequenceClassificationLoss(logits, labels, pooled_logits, **kwargs):
    config = kwargs["config"]
    num_labels = config.num_labels
    # Infer the problem type from the number of labels and the label dtype if unset
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
        else:
            loss = loss_fct(pooled_logits, labels)
    elif config.problem_type == "single_label_classification":
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(pooled_logits.view(-1, num_labels), labels.view(-1))
    elif config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(pooled_logits, labels)
    return loss


def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions):
    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, squeeze the extra dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1).to(start_logits.device)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1).to(end_logits.device)
        # Sometimes the start/end positions fall outside the model inputs; ignore these terms
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
    return total_loss


def ForTokenClassification(logits, labels, config, **kwargs):
    # Flatten the tokens and upcast to float to avoid potential precision issues
    logits = logits.view(-1, config.num_labels)
    labels = labels.view(-1)
    logits = logits.float()
    loss_fct = CrossEntropyLoss()
    return loss_fct(logits, labels)


LOSS_MAPPING = {
    "ForCausalLM": DefaultCrossEntropyLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
}

LOSS_MAPPING["ForSegmentation"] = ForSegmentationLoss
LOSS_MAPPING["ForObjectDetection"] = ForObjectDetectionLoss
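
For context, the keys in LOSS_MAPPING are model-class-name suffixes. A hedged sketch of how a caller might resolve the right loss function from a class name; the resolve_loss_fn helper is hypothetical, written for illustration, and is not the lookup code this PR adds:

def resolve_loss_fn(model_class_name, loss_mapping):
    # Match the longest registered suffix first, so a name like
    # "...ForTokenClassification" is not caught by a shorter key.
    for suffix in sorted(loss_mapping, key=len, reverse=True):
        if model_class_name.endswith(suffix):
            return loss_mapping[suffix]
    raise KeyError(f"No loss function registered for {model_class_name}")

# e.g. resolve_loss_fn("LlamaForCausalLM", LOSS_MAPPING) returns DefaultCrossEntropyLoss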