69 changes: 0 additions & 69 deletions src/transformers/configuration_utils.py
@@ -110,19 +110,6 @@ class PreTrainedConfig(PushToHubMixin):
Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
is_encoder_decoder (`bool`, *optional*, defaults to `False`):
Whether the model is used as an encoder/decoder or not.
is_decoder (`bool`, *optional*, defaults to `False`):
Whether to only use the decoder in an encoder-decoder architecture, otherwise it has no effect on
decoder-only or encoder-only architectures.
cross_attention_hidden_size (`bool`, *optional*):
The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
add_cross_attention (`bool`, *optional*, defaults to `False`):
Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
in `AUTO_MODELS_FOR_CAUSAL_LM`.
tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
and decoder model to have the exact same parameter names.
chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
@@ -133,45 +120,20 @@ class PreTrainedConfig(PushToHubMixin):

architectures (`list[str]`, *optional*):
Model architectures that can be used with the model pretrained weights.
finetuning_task (`str`, *optional*):
Name of the task used to fine-tune the model.
id2label (`dict[int, str]`, *optional*):
A map from index (for instance prediction index, or target index) to label.
label2id (`dict[str, int]`, *optional*):
A map from label to index for the model.
num_labels (`int`, *optional*):
Number of labels to use in the last layer added to the model, typically for a classification task.
task_specific_params (`dict[str, Any]`, *optional*):
Additional keyword arguments to store for the current task.
problem_type (`str`, *optional*):
Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
`"single_label_classification"` or `"multi_label_classification"`.

> Parameters linked to the tokenizer

tokenizer_class (`str`, *optional*):
The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
model by default).
prefix (`str`, *optional*):
A specific prompt that should be added at the beginning of each text before calling the model.
bos_token_id (`int`, *optional*):
The id of the _beginning-of-stream_ token.
pad_token_id (`int`, *optional*):
The id of the _padding_ token.
eos_token_id (`int`, *optional*):
The id of the _end-of-stream_ token.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
sep_token_id (`int`, *optional*):
The id of the _separation_ token.

> PyTorch specific parameters

torchscript (`bool`, *optional*, defaults to `False`):
Whether or not the model should be used with Torchscript.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has a output word embedding layer.
dtype (`str`, *optional*):
The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
(which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
@@ -209,29 +171,14 @@ def __init__(
torchscript: bool = False,
dtype: Optional[Union[str, "torch.dtype"]] = None,
# Common arguments
tie_word_embeddings: bool = True,
chunk_size_feed_forward: int = 0,
is_encoder_decoder: bool = False,
is_decoder: bool = False,
cross_attention_hidden_size: Optional[int] = None,
add_cross_attention: bool = False,
tie_encoder_decoder: bool = False,
# Fine-tuning task arguments
architectures: Optional[list[str]] = None,
finetuning_task: Optional[str] = None,
id2label: Optional[dict[int, str]] = None,
label2id: Optional[dict[str, int]] = None,
num_labels: Optional[int] = None,
task_specific_params: Optional[dict[str, Any]] = None,
problem_type: Optional[str] = None,
# Tokenizer kwargs
tokenizer_class: Optional[str] = None,
prefix: Optional[str] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
sep_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
**kwargs,
):
# Validation for some arguments
@@ -273,22 +220,15 @@ def __init__(
self._output_attentions = output_attentions # has public property

# Less common kwargs, only used by some models
self.tie_word_embeddings = tie_word_embeddings
self.chunk_size_feed_forward = chunk_size_feed_forward

# Encoder-decoder models attributes
self.is_encoder_decoder = is_encoder_decoder
self.is_decoder = is_decoder # used in encoder-decoder models to differentiate encoder from decoder
self.cross_attention_hidden_size = cross_attention_hidden_size
self.add_cross_attention = add_cross_attention
self.tie_encoder_decoder = tie_encoder_decoder

# Fine-tuning task attributes
self.architectures = architectures
self.finetuning_task = finetuning_task
self.id2label = id2label
self.label2id = label2id
self.task_specific_params = task_specific_params
self.problem_type = problem_type

if self.id2label is None:
@@ -297,15 +237,6 @@ def __init__(
# Keys are always strings in JSON so convert ids to int here.
self.id2label = {int(key): value for key, value in self.id2label.items()}

# Tokenizer attributes
self.tokenizer_class = tokenizer_class
self.prefix = prefix
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
self.sep_token_id = sep_token_id
self.decoder_start_token_id = decoder_start_token_id

# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
for parameter_name, default_value in self._get_global_generation_defaults().items():
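The practical effect of the configuration_utils.py change above is that token-id and encoder-decoder arguments are no longer named parameters of the base `__init__`; model configs that need them now assign the attributes themselves and pass the remaining arguments through `**kwargs`. A minimal sketch of the new pattern, using a hypothetical `MyModelConfig` (not part of this PR; names and defaults are assumptions):

# Hypothetical example of the post-change pattern.
from transformers.configuration_utils import PreTrainedConfig

class MyModelConfig(PreTrainedConfig):
    model_type = "my_model"  # hypothetical model type

    def __init__(
        self,
        vocab_size=32000,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        **kwargs,
    ):
        # Previously forwarded to super().__init__(); now assigned directly on the config.
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.vocab_size = vocab_size
        super().__init__(**kwargs)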
2 changes: 1 addition & 1 deletion src/transformers/integrations/executorch.py
@@ -1109,7 +1109,7 @@ def generate(self, prompt_token_ids, max_new_tokens):
decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long, device=model_device)

# Check if EOS token
if next_token == self.config.eos_token_id:
if next_token == self.generation_config.eos_token_id:
break

return generated_ids
7 changes: 4 additions & 3 deletions src/transformers/modeling_layers.py
@@ -138,13 +138,14 @@ def forward(
else:
batch_size = inputs_embeds.shape[0]

if self.config.pad_token_id is None and batch_size != 1:
pad_token_id = getattr(self.config, "pad_token_id", None)
if pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
if pad_token_id is None:
last_non_pad_token = -1
elif input_ids is not None:
# To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
else:
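The same defensive pattern used here — reading a config attribute with `getattr` and a fallback rather than assuming it is always set — applies to any downstream code that inspects configs which may or may not define these formerly-global attributes. A short sketch with a hypothetical helper (not from this PR):

import torch

def non_pad_mask(config, input_ids):
    """Boolean mask of non-padding positions; all-True if the config defines no pad token."""
    pad_token_id = getattr(config, "pad_token_id", None)
    if pad_token_id is None:
        # No padding token defined on this config: treat every position as a real token.
        return torch.ones_like(input_ids, dtype=torch.bool)
    return input_ids != pad_token_id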
5 changes: 2 additions & 3 deletions src/transformers/modeling_utils.py
@@ -3133,9 +3133,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
new_num_tokens = new_embeddings.weight.shape[0]

# if word embeddings are not tied, make sure that lm head is resized as well
if (
self.get_output_embeddings() is not None
and not self.config.get_text_config(decoder=True).tie_word_embeddings
if self.get_output_embeddings() is not None and not getattr(
self.config.get_text_config(decoder=True), "tie_word_embeddings", False
):
old_lm_head = self.get_output_embeddings()
if isinstance(old_lm_head, torch.nn.Embedding):
3 changes: 2 additions & 1 deletion src/transformers/models/aimv2/configuration_aimv2.py
@@ -195,7 +195,8 @@ def __init__(
initializer_range: bool = 0.02,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
super().__init__(**kwargs)
self.eos_token_id = eos_token_id

self.vocab_size = vocab_size
self.hidden_size = hidden_size
5 changes: 4 additions & 1 deletion src/transformers/models/albert/configuration_albert.py
@@ -122,7 +122,10 @@ def __init__(
eos_token_id=3,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
super().__init__(**kwargs)
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.vocab_size = vocab_size
self.embedding_size = embedding_size
5 changes: 4 additions & 1 deletion src/transformers/models/altclip/configuration_altclip.py
@@ -112,7 +112,10 @@ def __init__(
project_dim=768,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
super().__init__(**kwargs)
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.vocab_size = vocab_size
self.hidden_size = hidden_size
12 changes: 5 additions & 7 deletions src/transformers/models/apertus/configuration_apertus.py
@@ -203,13 +203,11 @@ def __init__(
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
super().__init__(**kwargs)


__all__ = ["ApertusConfig"]
12 changes: 5 additions & 7 deletions src/transformers/models/arcee/configuration_arcee.py
@@ -190,13 +190,11 @@ def __init__(
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
super().__init__(**kwargs)


__all__ = ["ArceeConfig"]
12 changes: 5 additions & 7 deletions src/transformers/models/aria/configuration_aria.py
@@ -213,13 +213,11 @@ def __init__(
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
super().__init__(**kwargs)


class AriaConfig(PreTrainedConfig):
2 changes: 1 addition & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -1094,7 +1094,7 @@ def from_pretrained(
config = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
config_tokenizer_class = config.tokenizer_class
config_tokenizer_class = getattr(config, "tokenizer_class", None)
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
tokenizer_auto_map = config.auto_map["AutoTokenizer"]

12 changes: 5 additions & 7 deletions src/transformers/models/bamba/configuration_bamba.py
@@ -191,13 +191,11 @@ def __init__(
self.mamba_proj_bias = mamba_proj_bias
self.z_loss_coefficient = z_loss_coefficient

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
super().__init__(**kwargs)

@property
def layers_block_type(self):
3 changes: 2 additions & 1 deletion src/transformers/models/bark/configuration_bark.py
@@ -177,7 +177,8 @@ def __init__(self, tie_word_embeddings=True, n_codes_total=8, n_codes_given=1, *
self.n_codes_total = n_codes_total
self.n_codes_given = n_codes_given

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
self.tie_word_embeddings = tie_word_embeddings
super().__init__(**kwargs)


class BarkConfig(PreTrainedConfig):
14 changes: 10 additions & 4 deletions src/transformers/models/bart/configuration_bart.py
@@ -137,8 +137,14 @@ def __init__(
is_encoder_decoder=True,
decoder_start_token_id=2,
forced_eos_token_id=2,
is_decoder=False,
add_cross_attention=False,
tie_word_embeddings=True,
**kwargs,
):
self.is_decoder = is_decoder
self.add_cross_attention = add_cross_attention
self.tie_word_embeddings = tie_word_embeddings
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
@@ -160,13 +166,13 @@ def __init__(
self.num_hidden_layers = encoder_layers
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True

self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.decoder_start_token_id = decoder_start_token_id
super().__init__(
num_labels=num_labels,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
forced_eos_token_id=forced_eos_token_id,
**kwargs,
)
7 changes: 6 additions & 1 deletion src/transformers/models/bert/configuration_bert.py
@@ -107,9 +107,14 @@ def __init__(
pad_token_id=0,
use_cache=True,
classifier_dropout=None,
is_decoder=False,
add_cross_attention=False,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
super().__init__(**kwargs)
self.pad_token_id = pad_token_id
self.is_decoder = is_decoder
self.add_cross_attention = add_cross_attention

self.vocab_size = vocab_size
self.hidden_size = hidden_size
@@ -98,10 +98,17 @@ def __init__(
bos_token_id=2,
eos_token_id=1,
use_cache=True,
is_decoder=False,
add_cross_attention=False,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
super().__init__(**kwargs)
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

self.is_decoder = is_decoder
self.add_cross_attention = add_cross_attention
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
18 changes: 10 additions & 8 deletions src/transformers/models/big_bird/configuration_big_bird.py
@@ -126,16 +126,18 @@ def __init__(
block_size=64,
num_random_blocks=3,
classifier_dropout=None,
is_decoder=False,
add_cross_attention=False,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
sep_token_id=sep_token_id,
**kwargs,
)

super().__init__(**kwargs)

self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.sep_token_id = sep_token_id
self.is_decoder = is_decoder
self.add_cross_attention = add_cross_attention
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size