Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 73 additions & 21 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2581,26 +2581,79 @@ def disable_input_require_grads(self):
"""
self._require_grads_hook.remove()

def get_encoder(self, modality: Optional[str] = None):
    """
    Best-effort lookup of the *encoder* module.

    Args:
        modality (`str`, *optional*):
            One of `"image"`, `"video"` or `"audio"` to look up a modality-specific
            encoder in multimodal models (e.g. the vision tower). When `None`, the
            model's text encoder is returned if any, otherwise `self`.

    Returns:
        The encoder submodule, or `self` when the model is itself the encoder.

    Raises:
        ValueError: if `modality` is not `None`, `"image"`, `"video"` or `"audio"`.
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
    if modality in ["image", "video"]:
        possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
    # BUG FIX: this branch used to start a *new* `if/elif/else` chain, so
    # modality="image"/"video" fell through to the `else` below and raised.
    elif modality == "audio":
        possible_module_names = ["audio_tower", "audio_encoder"]
    elif modality is None:
        possible_module_names = ["text_encoder", "encoder"]
    else:
        raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

    for name in possible_module_names:
        if hasattr(self, name):
            return getattr(self, name)

    # Delegate to the wrapped base model, forwarding the requested modality
    # (it was previously dropped, silently changing the lookup on recursion).
    if self.base_model is not self and hasattr(self.base_model, "get_encoder"):
        return self.base_model.get_encoder(modality)

    # If this is a base transformer model (no encoder/model attributes), return self
    return self

def set_encoder(self, encoder, modality: Optional[str] = None):
    """
    Symmetric setter. Mirrors the lookup logic used in `get_encoder`.

    Args:
        encoder: the module to install as the encoder.
        modality (`str`, *optional*):
            One of `"image"`, `"video"` or `"audio"` to target a modality-specific
            encoder; `None` targets the text encoder.

    Raises:
        ValueError: if `modality` is not `None`, `"image"`, `"video"` or `"audio"`.
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
    if modality in ["image", "video"]:
        possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
    # BUG FIX: this branch used to start a *new* `if/elif/else` chain, so
    # modality="image"/"video" fell through to the `else` below and raised.
    elif modality == "audio":
        possible_module_names = ["audio_tower", "audio_encoder"]
    elif modality is None:
        possible_module_names = ["text_encoder", "encoder"]
    else:
        raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

    for name in possible_module_names:
        if hasattr(self, name):
            # (debug `print(name)` removed)
            setattr(self, name, encoder)
            return

    if self.base_model is not self:
        if hasattr(self.base_model, "set_encoder"):
            # Forward the modality so the nested setter targets the same submodule.
            self.base_model.set_encoder(encoder, modality)
        else:
            self.model = encoder

def get_decoder(self):
    """
    Best-effort lookup of the *decoder* module.

    Order of attempts (covers ~85 % of current usages):

    1. `self.language_model` / `self.text_model` / `self.decoder` / `self.text_decoder`
    2. `self.base_model.get_decoder()` (many wrappers store the decoder there)
    3. fallback: return `self` (the model is itself the decoder, e.g. a bare
       `MistralModel`)
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
    possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
            return getattr(self, name)

    if self.base_model is not self and hasattr(self.base_model, "get_decoder"):
        return self.base_model.get_decoder()

    # If this is a base transformer model (no decoder/model attributes), return self
    # This handles cases like MistralModel which is itself the decoder
    return self
def set_decoder(self, decoder):
    """
    Symmetric setter. Mirrors the lookup logic used in `get_decoder`.

    Args:
        decoder: the module to install as the decoder.
    """
    # Keep this list in sync with `get_decoder` ("text_decoder" was missing here,
    # so models exposing only `text_decoder` could not be set symmetrically).
    possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
            # (debug `print(name)` removed)
            setattr(self, name, decoder)
            return

    if self.base_model is not self:
        if hasattr(self.base_model, "set_decoder"):
            self.base_model.set_decoder(decoder)
        else:
            self.model = decoder

def _init_weights(self, module):
"""
Expand Down
25 changes: 0 additions & 25 deletions src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.language_model = decoder

def get_decoder(self):
return self.language_model

def get_image_features(
self,
pixel_values: torch.FloatTensor,
Expand Down Expand Up @@ -1070,12 +1064,6 @@ def set_input_embeddings(self, value):
def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

def get_decoder(self):
return self.model.get_decoder()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
Expand All @@ -1088,19 +1076,6 @@ def get_image_features(
vision_feature_layer=vision_feature_layer,
)

# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

@property
def vision_tower(self):
return self.model.vision_tower

@property
def multi_modal_projector(self):
return self.model.multi_modal_projector

@can_return_tuple
@auto_docstring
def forward(
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/autoformer/modeling_autoformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,9 +1339,6 @@ def create_network_inputs(
)
return reshaped_lagged_sequence, features, loc, scale, static_feat

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1585,12 +1582,6 @@ def __init__(self, config: AutoformerConfig):
def output_params(self, decoder_output):
return self.parameter_projection(decoder_output[:, -self.config.prediction_length :, :])

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

@torch.jit.ignore
def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution:
sliced_params = params
Expand Down
25 changes: 0 additions & 25 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.language_model = decoder

def get_decoder(self):
return self.language_model

def get_image_features(
self,
pixel_values: torch.FloatTensor,
Expand Down Expand Up @@ -355,12 +349,6 @@ def set_input_embeddings(self, value):
def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

def get_decoder(self):
return self.model.get_decoder()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
Expand All @@ -375,19 +363,6 @@ def get_image_features(
**kwargs,
)

# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

@property
def vision_tower(self):
return self.model.vision_tower

@property
def multi_modal_projector(self):
return self.model.multi_modal_projector

@can_return_tuple
@auto_docstring
def forward(
Expand Down
15 changes: 0 additions & 15 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,9 +934,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1064,12 +1061,6 @@ def __init__(self, config: BartConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
Expand Down Expand Up @@ -1532,12 +1523,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2105,9 +2105,6 @@ def _tie_weights(self):
self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -2225,12 +2222,6 @@ def __init__(self, config: BigBirdPegasusConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
Expand Down Expand Up @@ -2640,12 +2631,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
Expand Down
15 changes: 0 additions & 15 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -1025,12 +1022,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
Expand Down Expand Up @@ -1203,12 +1194,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -860,9 +860,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -985,12 +982,6 @@ def __init__(self, config: BlenderbotSmallConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
Expand Down Expand Up @@ -1163,12 +1154,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
Expand Down
6 changes: 0 additions & 6 deletions src/transformers/models/blip_2/modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,9 +1073,6 @@ def get_output_embeddings(self) -> nn.Module:
def get_encoder(self):
return self.language_model.get_encoder()

def get_decoder(self):
return self.language_model.get_decoder()

def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
Expand Down Expand Up @@ -1636,9 +1633,6 @@ def get_output_embeddings(self) -> nn.Module:
def get_encoder(self):
return self.language_model.get_encoder()

def get_decoder(self):
return self.language_model.get_decoder()

def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
Expand Down
Loading