Merged (58 commits):
44301cc quick fix (ArthurZucker, Oct 16, 2024)
1456088 3 losses (ArthurZucker, Oct 16, 2024)
e57f00c oups (ArthurZucker, Oct 16, 2024)
7fa8503 fix (ArthurZucker, Oct 16, 2024)
b955ea5 nits (ArthurZucker, Oct 16, 2024)
07478e0 check how it scales for special models (ArthurZucker, Oct 16, 2024)
1b356ef propagate for conditiona detr (ArthurZucker, Oct 16, 2024)
4ef45b0 propagate (ArthurZucker, Oct 16, 2024)
61da9b1 propagate (ArthurZucker, Oct 16, 2024)
2e3f0f7 propagate (ArthurZucker, Oct 16, 2024)
c31a3fb fixes (ArthurZucker, Oct 16, 2024)
a8cd107 propagate changes (ArthurZucker, Oct 16, 2024)
711c357 update (ArthurZucker, Oct 16, 2024)
4888cf3 fixup (ArthurZucker, Oct 16, 2024)
4323d85 nits (ArthurZucker, Oct 16, 2024)
e5e4bbd f string (ArthurZucker, Oct 16, 2024)
239a256 fixes (ArthurZucker, Oct 16, 2024)
bd298da more fixes (ArthurZucker, Oct 16, 2024)
5dfc51c ? (ArthurZucker, Oct 16, 2024)
0a1cd2b nit (ArthurZucker, Oct 16, 2024)
64f7e29 arg annoying f string (ArthurZucker, Oct 16, 2024)
aa01ae9 nits (ArthurZucker, Oct 16, 2024)
8c1d68a grumble (ArthurZucker, Oct 16, 2024)
846cf1c update (ArthurZucker, Oct 16, 2024)
e7e8a20 nit (ArthurZucker, Oct 16, 2024)
622290c refactor (ArthurZucker, Oct 16, 2024)
91e28aa fix fetch tests (ArthurZucker, Oct 16, 2024)
da649b9 nit (ArthurZucker, Oct 16, 2024)
df6472a nit (ArthurZucker, Oct 16, 2024)
cf1eb7b Update src/transformers/loss/loss_utils.py (ArthurZucker, Oct 16, 2024)
dafd11b Merge branch 'quick-fix-ga' of github.com:huggingface/transformers in… (ArthurZucker, Oct 16, 2024)
30f27cd update (ArthurZucker, Oct 16, 2024)
d0edfad nit (ArthurZucker, Oct 16, 2024)
9bcecc3 fixup (ArthurZucker, Oct 16, 2024)
2839b3c make pass (ArthurZucker, Oct 16, 2024)
557d225 nits (ArthurZucker, Oct 16, 2024)
393e178 port code to more models (ArthurZucker, Oct 16, 2024)
aac054d fixup (ArthurZucker, Oct 16, 2024)
ce32d5e ntis (ArthurZucker, Oct 16, 2024)
4dc49ac arf (ArthurZucker, Oct 16, 2024)
d221e58 update (ArthurZucker, Oct 16, 2024)
f03b193 update (ArthurZucker, Oct 16, 2024)
22b6283 nits (ArthurZucker, Oct 16, 2024)
64829e3 update (ArthurZucker, Oct 16, 2024)
0b6f425 fix (ArthurZucker, Oct 16, 2024)
e6f6f52 update (ArthurZucker, Oct 16, 2024)
fa691aa nits (ArthurZucker, Oct 16, 2024)
66f6eef fine (ArthurZucker, Oct 16, 2024)
fcdf13d agjkfslga.jsdlkgjklas (ArthurZucker, Oct 16, 2024)
ece5e01 nits (ArthurZucker, Oct 17, 2024)
bb236eb fix fx? (ArthurZucker, Oct 17, 2024)
7c2b7ce update (ArthurZucker, Oct 17, 2024)
0be4379 update (ArthurZucker, Oct 17, 2024)
36d76d7 styel (ArthurZucker, Oct 17, 2024)
92979e7 fix imports (ArthurZucker, Oct 17, 2024)
a55e440 update (ArthurZucker, Oct 17, 2024)
b14c3dd update (ArthurZucker, Oct 17, 2024)
dbbc3ce fixup to fix the torch fx? (ArthurZucker, Oct 17, 2024)
3 changes: 3 additions & 0 deletions src/transformers/configuration_utils.py
@@ -184,6 +184,9 @@ class PretrainedConfig(PushToHubMixin):
            Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
            not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
            v5.
        loss_type (`str`, *optional*):
            The type of loss that the model should use. It should be one of the keys of `LOSS_MAPPING`; otherwise the
            loss will be automatically inferred from the model architecture.
    """

    model_type: str = ""
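To make the new option concrete, here is a minimal, hedged sketch of setting `loss_type` through a config; the checkpoint name and the chosen value are illustrative assumptions, not taken from this diff:

from transformers import AutoConfig, AutoModelForCausalLM

# Illustrative only: pin the loss explicitly instead of letting it be inferred
# from the model class name. "gpt2" is just an example checkpoint.
config = AutoConfig.from_pretrained("gpt2")
config.loss_type = "ForCausalLM"  # must be a key of LOSS_MAPPING
model = AutoModelForCausalLM.from_pretrained("gpt2", config=config)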
98 changes: 98 additions & 0 deletions src/transformers/loss_utils.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from .models.detr.loss_detr import ForObjectDetectionLoss, ForSegmentationLoss


def DefaultCrossEntropyLoss(logits, labels, **kwargs):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, kwargs["vocab_size"])
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    num_items = kwargs.pop("num_items", None)

    if num_items is not None:
        # Calculate the CrossEntropyLoss manually when using grad accum
        log_probs = nn.functional.log_softmax(shift_logits, dim=-1)
        loss = -log_probs[range(shift_labels.size(0)), shift_labels]
        loss = loss.sum() / num_items
    else:
        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

    return loss


def ForSequenceClassificationLoss(logits, labels, pooled_logits, **kwargs):
    config = kwargs["config"]
    num_labels = config.num_labels
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
        else:
            loss = loss_fct(pooled_logits, labels)
    elif config.problem_type == "single_label_classification":
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(pooled_logits.view(-1, num_labels), labels.view(-1))
    elif config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(pooled_logits, labels)
    return loss


def ForQuestionAnsweringLoss(start_logits, end_logits, start_positions, end_positions):
    total_loss = None
    if start_positions is not None and end_positions is not None:
        # If we are on multi-GPU, split add a dimension
        if len(start_positions.size()) > 1:
            start_positions = start_positions.squeeze(-1).to(start_logits.device)
        if len(end_positions.size()) > 1:
            end_positions = end_positions.squeeze(-1).to(end_logits.device)
        # sometimes the start/end positions are outside our model inputs, we ignore these terms
        ignored_index = start_logits.size(1)
        start_positions = start_positions.clamp(0, ignored_index)
        end_positions = end_positions.clamp(0, ignored_index)

        loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
        start_loss = loss_fct(start_logits, start_positions)
        end_loss = loss_fct(end_logits, end_positions)
        total_loss = (start_loss + end_loss) / 2
    return total_loss


def ForTokenClassification(logits, labels, config, **kwargs):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.view(-1, config.num_labels)
    labels = labels.view(-1)
    logits = logits.float()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    return loss_fct(logits, labels)


LOSS_MAPPING = {
    "ForCausalLM": DefaultCrossEntropyLoss,
    "ForQuestionAnswering": ForQuestionAnsweringLoss,
    "ForSequenceClassification": ForSequenceClassificationLoss,
    "ForTokenClassification": ForTokenClassification,
}

LOSS_MAPPING["ForSegmentation"] = ForSegmentationLoss
LOSS_MAPPING["ForObjectDetection"] = ForObjectDetectionLoss
29 changes: 28 additions & 1 deletion src/transformers/modeling_utils.py
@@ -28,7 +28,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from functools import lru_cache, partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from zipfile import is_zipfile
@@ -45,6 +45,7 @@
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .loss_utils import LOSS_MAPPING
from .pytorch_utils import (  # noqa: F401
    Conv1D,
    apply_chunking_to_forward,
@@ -4979,6 +4980,32 @@ def _is_quantized_training_enabled(self):

        return self.hf_quantizer.is_trainable

    @property
    @lru_cache
    def loss_function(self):
        if getattr(self.config, "loss_type", None) is not None:
            loss_type = self.config.loss_type
        else:
            loss_type = self.__class__.__name__
            if loss_type not in LOSS_MAPPING:
                loss_groups = f"({'|'.join(LOSS_MAPPING)})"
                loss_type = re.findall(loss_groups, self.__class__.__name__)
                if len(loss_type) > 0:
                    loss_type = loss_type[0]
                else:
                    loss_type = None
        if loss_type is None:
            raise ValueError(
                "We could not determine which loss function to use "
                f"based on the class name. Make sure you add `{self.__class__.__name__}` to the `LOSS_MAPPING`."
            )
        if loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
            raise ValueError(
                f"`loss_type={loss_type}` was set in the config but it is unrecognised. "
                f"Make sure you add `{loss_type}` to the `LOSS_MAPPING`."
            )
        return LOSS_MAPPING[loss_type]


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
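For intuition, here is a small, hedged sketch of the class-name fallback implemented by the property above; `LlamaForCausalLM` is only an example name, and the snippet re-implements the regex lookup standalone rather than calling the property itself:

import re

# Keys mirror those registered in LOSS_MAPPING in loss_utils.py above.
LOSS_MAPPING_KEYS = ["ForCausalLM", "ForQuestionAnswering", "ForSequenceClassification", "ForTokenClassification"]

def infer_loss_type(class_name: str):
    # Mirrors the property's fallback: look for any known suffix inside the class name.
    loss_groups = f"({'|'.join(LOSS_MAPPING_KEYS)})"
    matches = re.findall(loss_groups, class_name)
    return matches[0] if matches else None

print(infer_loss_type("LlamaForCausalLM"))  # -> "ForCausalLM"
print(infer_loss_type("LlamaModel"))        # -> None (the property would raise here)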