Add TimmWrapper #34564
Conversation
src/transformers/modeling_utils.py
There are 2 changes in this file:
- state_dict key renaming is moved into a separate method so it can be overridden for TimmWrapper (disable gamma/beta renaming + add the prefix) — sketched below
- metadata is None for timm checkpoints -> assume these are PyTorch checkpoints
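Roughly, the override idea looks like this (a sketch only; the class and method names here are illustrative, not the actual ones in modeling_utils.py):

```python
class PreTrainedModelSketch:
    def rename_state_dict_key(self, key: str) -> str:
        # Default behaviour: legacy checkpoints used "gamma"/"beta"
        # for norm parameters, so they get renamed on load.
        return key.replace("gamma", "weight").replace("beta", "bias")


class TimmWrapperSketch(PreTrainedModelSketch):
    def rename_state_dict_key(self, key: str) -> str:
        # timm checkpoints can legitimately contain "gamma" parameters
        # (e.g. layer scale), so skip the renaming entirely and only
        # prepend the wrapper prefix so keys match the wrapped module.
        return f"timm_model.{key}"
```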
```python
default_image_processor_filename = (
    "config.json" if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
)
kwargs["image_processor_filename"] = kwargs.get("image_processor_filename", default_image_processor_filename)
```
timm checkpoints store image processor config in config.json
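For example, loading the processor for a timm checkpoint reads its settings out of config.json rather than preprocessor_config.json (a sketch; assumes the wrapper's image processor is registered with the Auto classes):

```python
from transformers import AutoImageProcessor

# "timm/resnet18.a1_in1k" is a timm-primary hub checkpoint: it ships a
# config.json with a pretrained_cfg, but no preprocessor_config.json.
processor = AutoImageProcessor.from_pretrained("timm/resnet18.a1_in1k")
print(processor)
```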
@rwightman could you please do the first review, in case you have bandwidth?
@qubvel I'm starting to work through it now; wanted to get eval working to check some familiar models and just wasted way too much time realizing I needed … Anyways, on a first pass of the code things looked sane, but I need to spend some time looking closer at the details and testing some cases.
A few high level q...
@rwightman thanks for the review! Indeed, there are some default params I'm also confused about; it's even more confusing for object detection 🥲
Left it as it was in the previous PR, however,
The prefix "timm_model" is unique and is used in certain tests to identify when weights come from a timm model. It's also utilized in the
Originally, this was implemented without
I left a comment in the thread about this. We use transformers for weight loading to leverage features like device_map, torch_dtype, and quantization. I’m also unsure how to disable weight loading through transformers if it’s handled by timm, as I haven’t seen any examples of this in the repo. I can dig into it further if needed. Do you have an estimate of how many models involve weight renaming? Is there a way to update checkpoints without breaking older versions of timm? Alternatively, could we manage similar weight renaming directly in transformers? (Though I think this approach may be less robust.)
I’m not sure if it's currently compatible, but it would be great to enable it! The config should be compatible; however, the weight state_dict will have the "model.timm_model." prefix. I can look into removing this prefix before saving. I will try to enable it and add a separate test for a few checkpoints. Thanks for bringing this up!
I feel in this case there is a difference in alignment with other models, because both the ImageClassification and base model wrap timm model instances that differ, instead of adding their own head to the same timm model. I see some options depending on the mix of hidden-state flags, head vs no head (where it'd be appropriate to go through different forward calls), where more than just one argument might be appropriate to change on creation, etc. Also, thinking about the future and other tasks, I feel there is a high probability that, say, for supporting native timm object detection, more flexibility is desired, so it'd be safer to have it uncoupled. That, and there's a resistance to significant changes in transformers after something is in there, so I feel it's better to leave it uncoupled to keep additional flexibility and avoid being stuck with tricky decisions that might impact timm moving forward.
Aside from renaming, it just doesn't work right now: you can't load weights for a model if the head size changes for a new classification task. The wrapper only works if you use the imagenet classifier. I realize timm doesn't have the dtype / lazy-init features, but it's better to have it work, I feel. Can potentially look at supporting some of that through timm; it hasn't been a priority as there aren't too many very large models in timm. If there's no way to use pretrained=True on creation, then we'll probably need to figure out how to add & call a method in timm once transformers has the state_dict and before it's loaded into the model. Not sure if there's a spot for such a call in transformers?
```python
    Empty init weights function to ensure compatibility of the class in the library.
    """
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
```
it's not clear why this is here or what it's attempting to do. timm has model-specific init fns, though they aren't separately callable right now; doing something like this could overwrite timm defaults and would change model behaviour
Added this to initialize the classifier; without it, the weights are not properly initialized, probably due to how the model is created in transformers:
```python
from transformers import TimmWrapperForImageClassification

# --------------
# With init
# --------------
model = TimmWrapperForImageClassification.from_pretrained(
    "timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True
)
print(model.timm_model.fc.weight[:3, :3])
# tensor([[-0.2117, -0.2422, -0.2540],
#         [-0.1106, -0.1856, -0.0152],
#         [-0.3430, -0.6446, -0.0530]], grad_fn=<SliceBackward0>)

# --------------
# Without init
# --------------
# patch with an empty init weight function to check
def empty_init(self, module):
    pass

TimmWrapperForImageClassification._init_weights = empty_init

model = TimmWrapperForImageClassification.from_pretrained(
    "timm/resnet18.a1_in1k", num_labels=10, ignore_mismatched_sizes=True
)
print(model.timm_model.fc.weight[:3, :3])
# tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [0., 0., 0.]], grad_fn=<SliceBackward0>)
```
Ideally, we should get rid of this, but it's not common for transformers to load external models, so it might require more code changes. For now, it's a simple fix to enable model loading with an initialized classifier when shapes are mismatched.
```python
if is_dir and os.path.exists(os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)):
    # timm models don't have a preprocessor_config.json file saved out
    return False
```
I don't think the absence of a file is a good check
Fixed!
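For reference, one sturdier alternative could inspect the contents of config.json instead (a sketch; the exact keys are an assumption based on how timm typically writes its configs):

```python
import json


def looks_like_timm_config(config_path: str) -> bool:
    # timm configs typically carry "architecture" and "pretrained_cfg"
    # keys, while transformers configs carry "model_type"; checking key
    # presence is sturdier than checking that another file is absent.
    with open(config_path) as f:
        config = json.load(f)
    return "pretrained_cfg" in config or (
        "architecture" in config and "model_type" not in config
    )
```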
If there is any timm model/class-specific function that fixes the state_dict, I suppose we can call it before loading weights, similar to
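timm does ship per-model state_dict filter functions; for example, `checkpoint_filter_fn` in `timm.models.vision_transformer` remaps original/legacy checkpoints into timm's layout. A sketch of what calling one before loading might look like (whether and where the wrapper should invoke such a hook is exactly the open question here):

```python
import timm
from timm.models.vision_transformer import checkpoint_filter_fn

model = timm.create_model("vit_base_patch16_224")

# Stand-in for a freshly downloaded checkpoint; a real one might use
# an original (non-timm) key layout that the filter fn remaps.
state_dict = model.state_dict()
state_dict = checkpoint_filter_fn(state_dict, model)
model.load_state_dict(state_dict)
```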
Thinking about this a bit more, there is another issue with the hub / transformers-first weight loading. For example, this model https://huggingface.co/laion/CLIP-ViT-B-16-laion2B-s34B-b88K is an OpenCLIP-first model, but timm can load it if you use the model name 'vit_base_patch16_clip_224.laion2b'. Indeed, I feel the wrapper should support all timm model names that work in timm, but right now if the model isn't on the hub with a timm-primary config it isn't usable. Ideally it should work with both a hub model name OR any timm model name. The timm model name path would require timm to do the pretrained loading, to resolve any translation to another hub name or weight source. Some very popular examples of this
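What already works in timm directly (real timm API; shown to illustrate the name resolution the wrapper currently lacks):

```python
import timm

# timm resolves this model name to weights hosted in an OpenCLIP hub
# repo (laion/CLIP-ViT-B-16-laion2B-s34B-b88K) even though that repo
# has no timm-primary config.json.
model = timm.create_model("vit_base_patch16_clip_224.laion2b", pretrained=True)
model.eval()
```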
Yay, I was able to run a fine-tune using run_image_classification.py after this fix. An observation: the output files for that script aren't directly usable to push to the hub. There is no config matching the timm format or name; there is a config output by the image preprocessor save process, to a different filename, that's a jumble of the timm config. Also, the state dict for the model has the timm_model prefix, so it's not loadable in timm. Is there any way to remove that prefix in the checkpoints? This would also make local dir loads more seamless if someone already had timm weights and config files checked out.
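A sketch of stripping the prefix so the saved weights load in plain timm (file names are placeholders, and the thread mentions both a "timm_model." and a "model.timm_model." prefix, so the exact prefix to strip is an assumption):

```python
import timm
from safetensors.torch import load_file, save_file

# Load a checkpoint saved by the wrapper and drop the wrapper prefix.
state_dict = load_file("output_dir/model.safetensors")
stripped = {k.removeprefix("timm_model."): v for k, v in state_dict.items()}
save_file(stripped, "output_dir/timm_model.safetensors")

# The stripped weights should now load into the bare timm model
# (architecture and head size must match the fine-tuned checkpoint).
model = timm.create_model("resnet18", num_classes=10)
model.load_state_dict(stripped)
```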
What does this PR do?
Adds a TimmWrapper set of classes so that timm models can be loaded and used as transformers models in the library.
Continuation of
General Usage
Pipeline
Timm models can now be used in the image classification pipeline (if a classification model) and the image feature extraction pipeline, as in the sketch below.
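For instance (a sketch; assumes timm checkpoints are wired into the pipeline machinery by this PR):

```python
from transformers import pipeline

# Image classification with a timm checkpoint loaded via TimmWrapper.
classifier = pipeline("image-classification", model="timm/resnet18.a1_in1k")
print(classifier("path/to/image.jpg")[:3])  # placeholder image path
```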
Trainer
Timm models can now be loaded and trained with the Trainer class.
Example model trained with the Trainer via the run_image_classification.py example script (a minimal sketch follows the link below):
https://huggingface.co/qubvel-hf/vit-base-beans
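A minimal Python sketch of the same flow (the dummy dataset, label count, and hyperparameters are placeholders, not the settings used for the linked model):

```python
import torch
from torch.utils.data import Dataset

from transformers import TimmWrapperForImageClassification, Trainer, TrainingArguments


class DummyImageDataset(Dataset):
    # Tiny stand-in dataset: random images with 3 classes.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"pixel_values": torch.randn(3, 224, 224), "labels": idx % 3}


# Load a timm checkpoint through the wrapper with a fresh 3-class head.
model = TimmWrapperForImageClassification.from_pretrained(
    "timm/resnet18.a1_in1k", num_labels=3, ignore_mismatched_sizes=True
)

args = TrainingArguments(output_dir="out", num_train_epochs=1, report_to="none")
trainer = Trainer(model=model, args=args, train_dataset=DummyImageDataset())
trainer.train()
```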
Other features enabled
- `output_hidden_states=True` or `output_hidden_states=[1, 2, 3]` to select specific hidden states (see the sketch below)
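A quick sketch of how the flag is meant to be used (the output attribute name follows the usual transformers conventions, which is an assumption here):

```python
import torch
from transformers import TimmWrapperForImageClassification

model = TimmWrapperForImageClassification.from_pretrained("timm/resnet18.a1_in1k")
pixel_values = torch.randn(1, 3, 224, 224)

# Return all intermediate feature maps...
outputs = model(pixel_values, output_hidden_states=True)
# ...or only the selected stage indices.
outputs = model(pixel_values, output_hidden_states=[1, 2, 3])
print(len(outputs.hidden_states))
```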
TODO
- `output_hidden_states` tests
- weight renaming in `transformers` instead of `timm`: which architectures are affected?