-
Notifications
You must be signed in to change notification settings - Fork 31.2k
🚨 Generalize get_decoder() for multimodal and delete redundant code 🔪
#42156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
304d3be
37415a9
0fd4474
2efe34b
f3bfd28
bf4bebd
73bdc27
07af770
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2581,26 +2581,79 @@ def disable_input_require_grads(self): | |
| """ | ||
| self._require_grads_hook.remove() | ||
|
|
||
def get_encoder(self, modality: Optional[str] = None):
    """
    Best-effort lookup of the *encoder* module.

    If provided with a `modality` argument, looks for a modality-specific encoder in
    multimodal models (e.g. the vision tower for `modality="image"`). By default
    (`modality=None`) the function returns the model's text encoder if any, and
    otherwise returns `self`.

    Args:
        modality (`str`, *optional*):
            One of `"image"`, `"video"` or `"audio"`. `None` selects the text encoder.

    Returns:
        `nn.Module`: the encoder module, or `self` when no dedicated encoder exists.

    Raises:
        ValueError: if `modality` is not one of the recognized values.
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
    if modality in ["image", "video"]:
        possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
    elif modality == "audio":
        # `elif`, not `if`: otherwise image/video modalities fall through to the
        # final `else` and raise even though they were handled above.
        possible_module_names = ["audio_tower", "audio_encoder"]
    elif modality is None:
        possible_module_names = ["text_encoder", "encoder"]
    else:
        raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

    for name in possible_module_names:
        if hasattr(self, name):
            return getattr(self, name)

    # Delegate to the wrapped base model when it exposes its own lookup.
    if self.base_model is not self and hasattr(self.base_model, "get_encoder"):
        return self.base_model.get_encoder()

    # If this is a base transformer model (no encoder/model attributes), return self
    return self
|
|
||
def set_encoder(self, encoder, modality: Optional[str] = None):
    """
    Symmetric setter. Mirrors the lookup logic used in `get_encoder`.

    Args:
        encoder (`nn.Module`):
            The module to install as the encoder.
        modality (`str`, *optional*):
            One of `"image"`, `"video"` or `"audio"`. `None` targets the text encoder.

    Raises:
        ValueError: if `modality` is not one of the recognized values.
    """
    # NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely.
    # NOTE(review): keep this name list in sync with `get_encoder`; reviewers should
    # push contributors toward the existing names rather than adding new ones.
    if modality in ["image", "video"]:
        possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
    elif modality == "audio":
        # `elif`, not `if`: a bare `if` here would send image/video into the
        # final `else` branch and raise a spurious ValueError.
        possible_module_names = ["audio_tower", "audio_encoder"]
    elif modality is None:
        possible_module_names = ["text_encoder", "encoder"]
    else:
        raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

    for name in possible_module_names:
        if hasattr(self, name):
            # (debug `print(name)` removed)
            setattr(self, name, encoder)
            return

    # Delegate to the wrapped base model; fall back to overwriting `self.model`.
    if self.base_model is not self:
        if hasattr(self.base_model, "set_encoder"):
            self.base_model.set_encoder(encoder)
        else:
            self.model = encoder
|
|
||
def get_decoder(self):
    """
    Best-effort lookup of the *decoder* module.

    Order of attempts (covers ~85 % of current usages):

    1. `self.language_model` / `self.text_model` / `self.decoder` / `self.text_decoder`
    2. `self.base_model.get_decoder()` (nested wrappers)
    3. fallback: return `self` (base transformer models are their own decoder)
    """
    # Checked in priority order; first matching attribute wins.
    possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
            return getattr(self, name)

    # Nested wrappers: delegate to the wrapped base model's own lookup.
    if self.base_model is not self and hasattr(self.base_model, "get_decoder"):
        return self.base_model.get_decoder()

    # If this is a base transformer model (no decoder/model attributes), return self
    # This handles cases like MistralModel which is itself the decoder
    return self
|
|
def set_decoder(self, decoder):
    """
    Symmetric setter. Mirrors the lookup logic used in `get_decoder`.

    Args:
        decoder (`nn.Module`):
            The module to install as the decoder.
    """
    # Must mirror `get_decoder`'s list ("text_decoder" included) so that
    # set_decoder(get_decoder(...)) round-trips onto the same attribute.
    possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
    for name in possible_module_names:
        if hasattr(self, name):
            # (debug `print(name)` removed)
            setattr(self, name, decoder)
            return

    # Delegate to the wrapped base model; fall back to overwriting `self.model`.
    if self.base_model is not self:
        if hasattr(self.base_model, "set_decoder"):
            self.base_model.set_decoder(decoder)
        else:
            self.model = decoder
    return
|
|
||
| def _init_weights(self, module): | ||
| """ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.